1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ 18 19 #include <algorithm> 20 #include <vector> 21 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/tensor.pb.h" 24 #include "tensorflow/core/framework/tensor_shape.pb.h" 25 #include "tensorflow/core/framework/type_traits.h" 26 #include "tensorflow/core/platform/protobuf.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace tensor { 31 32 // DeepCopy returns a tensor whose contents are a deep copy of the 33 // contents of 'other'. This function is intended only for 34 // convenience, not speed. 35 // 36 // REQUIRES: 'other' must point to data stored in CPU memory. 37 // REQUIRES: 'other' must be a Tensor of a copy-able type if 38 // 'other' is not appropriately memory-aligned. 39 Tensor DeepCopy(const Tensor& other); 40 41 // Concatenates 'tensors' into a single tensor, along their 0th dimension. 42 // 43 // REQUIRES: All members of 'tensors' must have the same data type parameter. 44 // REQUIRES: Each member of 'tensors' must have at least one dimension. 45 // REQUIRES: Each member of 'tensors' must point to data stored in CPU memory. 46 // REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it 47 // is not appropriately memory-aligned. 48 Status Concat(const gtl::ArraySlice<Tensor>& tensors, 49 Tensor* result) TF_MUST_USE_RESULT; 50 51 // Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th 52 // dimension. The ith output tensor has 0th-dimension size 'sizes[i]'. 53 // 54 // REQUIRES: 'tensor' must have at least one dimension. 55 // REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'. 56 // REQUIRES: 'tensor' must point to data stored in CPU memory. 57 // REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not 58 // appropriately memory-aligned. 59 // 60 // Split() and Concat() are inverse operations. 61 Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes, 62 std::vector<Tensor>* result) TF_MUST_USE_RESULT; 63 64 namespace internal { 65 void SetTensorProtoShape(std::vector<size_t> shape, 66 TensorShapeProto* shape_proto); 67 68 template <typename Type> 69 class TensorProtoFieldHelper : public std::false_type {}; 70 71 #define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME) \ 72 template <> \ 73 class TensorProtoFieldHelper<TYPE> : public std::true_type { \ 74 public: \ 75 typedef decltype( \ 76 std::declval<TensorProto>().FIELDNAME##_val(0)) FieldType; \ 77 typedef decltype( \ 78 std::declval<TensorProto>().FIELDNAME##_val()) RepeatedFieldType; \ 79 typedef decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val()) \ 80 MutableRepeatedFieldType; \ 81 static MutableRepeatedFieldType GetMutableField(TensorProto* proto) { \ 82 return proto->mutable_##FIELDNAME##_val(); \ 83 } \ 84 static RepeatedFieldType& GetField(const TensorProto& proto) { \ 85 return proto.FIELDNAME##_val(); \ 86 } \ 87 } 88 89 // The argument pairs in the following macro instantiations encode the 90 // mapping from C++ type ($1) to repeated field name "$2_val" used for storing 91 // values in TensorProto. See tensorflow/core/framework/tensor.proto. 92 DEFINE_PROTO_FIELD_HELPER(float, float); 93 DEFINE_PROTO_FIELD_HELPER(double, double); 94 DEFINE_PROTO_FIELD_HELPER(int8, int); 95 DEFINE_PROTO_FIELD_HELPER(uint8, int); 96 DEFINE_PROTO_FIELD_HELPER(int16, int); 97 DEFINE_PROTO_FIELD_HELPER(uint16, int); 98 DEFINE_PROTO_FIELD_HELPER(int32, int); 99 DEFINE_PROTO_FIELD_HELPER(uint32, uint32); 100 DEFINE_PROTO_FIELD_HELPER(int64, int64); 101 DEFINE_PROTO_FIELD_HELPER(uint64, uint64); 102 DEFINE_PROTO_FIELD_HELPER(bool, bool); 103 DEFINE_PROTO_FIELD_HELPER(qint8, int); 104 DEFINE_PROTO_FIELD_HELPER(quint8, int); 105 DEFINE_PROTO_FIELD_HELPER(qint16, int); 106 DEFINE_PROTO_FIELD_HELPER(quint16, int); 107 DEFINE_PROTO_FIELD_HELPER(qint32, int); 108 DEFINE_PROTO_FIELD_HELPER(Eigen::half, half); 109 DEFINE_PROTO_FIELD_HELPER(bfloat16, half); 110 DEFINE_PROTO_FIELD_HELPER(complex64, scomplex); 111 DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex); 112 113 #undef DEFINE_PROTO_HELPER 114 115 template <typename T> 116 struct CopyHelper { 117 template <typename SrcIter, typename DstIter> ToArrayCopyHelper118 static void ToArray(SrcIter begin, SrcIter end, DstIter dst) { 119 using SrcType = typename std::iterator_traits<SrcIter>::value_type; 120 using DstType = typename std::iterator_traits<DstIter>::value_type; 121 std::transform(begin, end, dst, [](const SrcType& x) -> DstType { 122 return static_cast<DstType>(x); 123 }); 124 } 125 template <typename SrcIter> ToArrayCopyHelper126 static void ToArray(SrcIter begin, SrcIter end, SrcIter dst) { 127 std::copy(begin, end, dst); 128 } 129 template <typename SrcIter, typename DstIter> FromArrayCopyHelper130 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { 131 ToArray(begin, end, dst); 132 } 133 }; 134 135 // Overloads for Eigen::half and bfloat16 that are 16 bits in size but are 136 // stored in an int32 field. 137 template <> 138 struct CopyHelper<Eigen::half> { 139 template <typename SrcIter> 140 static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) { 141 std::transform(begin, end, dst, [](int x) -> Eigen::half { 142 Eigen::half h; 143 h.x = static_cast<uint16>(x); 144 return h; 145 }); 146 } 147 template <typename SrcIter, typename DstIter> 148 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { 149 std::transform(begin, end, dst, 150 [](Eigen::half h) -> int { return static_cast<int>(h.x); }); 151 } 152 }; 153 154 template <> 155 struct CopyHelper<bfloat16> { 156 template <typename SrcIter> 157 static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) { 158 std::transform(begin, end, dst, [](int x) -> bfloat16 { 159 bfloat16 bf16; 160 bf16.value = static_cast<uint16>(x); 161 return bf16; 162 }); 163 } 164 template <typename SrcIter, typename DstIter> 165 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { 166 std::transform(begin, end, dst, [](bfloat16 bf16) -> int { 167 return static_cast<int>(bf16.value); 168 }); 169 } 170 }; 171 172 // Overloads for complex types that store real and imaginary parts 173 // at indices 2*i and 2*i+1 in float or double field. 174 template <typename RealType> 175 struct CopyHelper<std::complex<RealType>> { 176 template <typename SrcIter> 177 static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) { 178 using SrcType = typename std::iterator_traits<SrcIter>::value_type; 179 RealType* real_dst = reinterpret_cast<RealType*>(dst); 180 std::copy(begin, end, real_dst); 181 } 182 183 template <typename SrcIter, typename DstIter> 184 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { 185 using DstType = typename std::iterator_traits<DstIter>::value_type; 186 size_t n = std::distance(begin, end); 187 const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin)); 188 std::copy_n(real_begin, 2 * n, dst); 189 } 190 }; 191 192 // Helper class to extract and insert values into TensorProto represented as 193 // repeated fields. 194 template <typename T> 195 class TensorProtoHelper : public std::true_type { 196 public: 197 using FieldHelper = TensorProtoFieldHelper<T>; 198 using FieldType = typename TensorProtoFieldHelper<T>::FieldType; 199 200 static DataType GetDataType() { return DataTypeToEnum<T>::value; } 201 202 // Returns the number of values of type T encoded in the proto. 203 static size_t NumValues(const TensorProto& proto) { 204 size_t raw_size = FieldHelper::GetField(proto).size(); 205 return is_complex<T>::value ? raw_size / 2 : raw_size; 206 } 207 208 static void AddValue(const T& value, TensorProto* proto) { 209 const T* val_ptr = &value; 210 AddValues(val_ptr, val_ptr + 1, proto); 211 } 212 213 static T GetValue(size_t index, const TensorProto& proto) { 214 T val; 215 if (is_complex<T>::value) index *= 2; 216 CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index, 217 FieldHelper::GetField(proto).begin() + index + 1, 218 &val); 219 return val; 220 } 221 222 template <typename IterType> 223 static void AddValues(IterType begin, IterType end, TensorProto* proto) { 224 size_t n = std::distance(begin, end); 225 FieldType* dst = AppendUninitialized(n, proto); 226 CopyHelper<T>::FromArray(begin, end, dst); 227 } 228 229 template <typename IterType> 230 static void CopyValues(IterType dst, const TensorProto& proto) { 231 CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(), 232 FieldHelper::GetField(proto).end(), dst); 233 } 234 235 static void Truncate(size_t new_size, TensorProto* proto) { 236 if (is_complex<T>::value) new_size *= 2; 237 FieldHelper::GetMutableField(proto)->Truncate(new_size); 238 } 239 240 static FieldType* AppendUninitialized(size_t n, TensorProto* proto) { 241 if (is_complex<T>::value) n *= 2; 242 auto* field = FieldHelper::GetMutableField(proto); 243 field->Reserve(field->size() + n); 244 return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n)); 245 } 246 }; 247 248 // Specialization for string. 249 template <> 250 class TensorProtoHelper<string> : public std::true_type { 251 public: 252 static DataType GetDataType() { return DataType::DT_STRING; } 253 static void AddValue(const string& value, TensorProto* proto) { 254 *proto->mutable_string_val()->Add() = value; 255 } 256 template <typename IterType> 257 static void AddValues(IterType begin, IterType end, TensorProto* proto) { 258 for (IterType it = begin; it != end; ++it) { 259 AddValue(*it, proto); 260 } 261 } 262 template <typename IterType> 263 static void CopyToTensorContent(IterType begin, IterType end, 264 TensorProto* proto) { 265 AddValues(begin, end, proto); 266 } 267 }; 268 269 } // namespace internal 270 271 // Creates a 'TensorProto' with specified shape and values. 272 // The dtype and a field to represent data values of the returned 'TensorProto' 273 // are determined based on type of the 'values' parameter. 274 template <typename Type> 275 typename std::enable_if<internal::TensorProtoHelper<Type>::value, 276 TensorProto>::type 277 CreateTensorProto(const std::vector<Type>& values, 278 const std::vector<size_t>& shape) { 279 TensorProto tensor; 280 TensorShapeProto tensor_shape_proto; 281 internal::SetTensorProtoShape(shape, &tensor_shape_proto); 282 if (TensorShape(tensor_shape_proto).num_elements() != values.size()) { 283 LOG(ERROR) << "Shape and number of values (" << values.size() 284 << ") are incompatible."; 285 return tensor; 286 } 287 using TypeHelper = internal::TensorProtoHelper<Type>; 288 tensor.set_dtype(TypeHelper::GetDataType()); 289 tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto); 290 TypeHelper::AddValues(values.begin(), values.end(), &tensor); 291 return tensor; 292 } 293 294 // Converts values in tensor to run-length encoded compressed form. 295 // 296 // The elements of a tensor can be stored in a TensorProto in one of the 297 // following two forms: 298 // 1. As a raw byte string in the field `tensor_content` containing the 299 // serialized in-memory representation of the tensor. 300 // 2. As values of a repeated field depending on the datatype, e.g. that 301 // values of a DT_FLOAT tensor would be stored in the repeated field 302 // `float_val`. 303 // Storage scheme 2 may use a simple form of run-length encoding to compress 304 // data: If the values contains a tail of identical values, the repeated field 305 // will be truncated such that the number of values in the repeated field is 306 // less than the number of elements implied by the field`tensor_shape`. The 307 // original tensor can be recovered by repeating the final value in the repeated 308 // field. 309 // 310 // The TensorProto will be compressed if a) the tensor contains at least 311 // min_num_elements elements and b) the compressed tensor proto is would be at 312 // most the size of the original tensor proto divided by min_compression_ratio. 313 // 314 // Returns true if the tensor was compressed. 315 bool CompressTensorProtoInPlace(int64 min_num_elements, 316 float min_compression_ratio, 317 TensorProto* tensor); 318 319 inline bool CompressTensorProtoInPlace(TensorProto* tensor) { 320 static const int64 kDefaultMinNumElements = 64; 321 static const float kDefaultMinCompressionRatio = 2.0f; 322 return CompressTensorProtoInPlace(kDefaultMinNumElements, 323 kDefaultMinCompressionRatio, tensor); 324 } 325 326 } // namespace tensor 327 } // namespace tensorflow 328 329 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ 330