1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Utilities for saving/restoring tensor slice checkpoints. 17 18 #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 19 #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 20 21 #include <string> // for string 22 #include "tensorflow/core/framework/tensor.pb.h" 23 #include "tensorflow/core/framework/tensor_slice.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/lib/core/status.h" // for Status 26 #include "tensorflow/core/platform/protobuf.h" 27 28 namespace tensorflow { 29 30 namespace checkpoint { 31 32 // The key for the metadata in the tensor slice checkpoint files. It is "" so 33 // that the metadata is always at the beginning of a checkpoint file. 34 extern const char kSavedTensorSlicesKey[]; 35 36 // Encode a tensor name + a tensor slice into an ordered code and outputs it as 37 // a string. 38 // The format is 39 // <0> 40 // <tensor_name> 41 // <rank> 42 // <dim-0-start><dim-0-length> 43 // <dim-1-start><dim-1-length> 44 // ... 45 46 string EncodeTensorNameSlice(const string& name, 47 const tensorflow::TensorSlice& slice); 48 49 // Parse out the name and the slice from string encoded as an ordered code. 50 Status DecodeTensorNameSlice(const string& code, string* name, 51 tensorflow::TensorSlice* slice); 52 53 // Extracts the full shape, slice spec, and shape of the slice from 54 // "shape_and_slice". On non-OK return, caller must clear the out-arguments 55 // before reusing. 56 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, 57 TensorSlice* slice, TensorShape* shape_slice); 58 59 template <typename T> 60 struct SaveTypeTraits; 61 62 template <typename T> 63 const typename SaveTypeTraits<T>::SavedType* TensorProtoData( 64 const TensorProto& t); 65 66 template <typename T> 67 typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData( 68 TensorProto* t); 69 70 template <typename T> 71 void Fill(T* data, size_t n, TensorProto* t); 72 73 #define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE) \ 74 template <> \ 75 struct SaveTypeTraits<TYPE> { \ 76 static constexpr bool supported = true; \ 77 typedef STYPE SavedType; \ 78 typedef protobuf::RepeatedField<FTYPE> RepeatedField; \ 79 }; \ 80 template <> \ 81 inline const STYPE* TensorProtoData<TYPE>(const TensorProto& t) { \ 82 static_assert(SaveTypeTraits<TYPE>::supported, \ 83 "Specified type " #TYPE " not supported for Restore"); \ 84 return reinterpret_cast<const STYPE*>(t.FIELD##_val().data()); \ 85 } \ 86 template <> \ 87 inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \ 88 TensorProto * t) { \ 89 static_assert(SaveTypeTraits<TYPE>::supported, \ 90 "Specified type " #TYPE " not supported for Save"); \ 91 return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \ 92 t->mutable_##FIELD##_val()); \ 93 } 94 95 #define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ 96 TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE) \ 97 template <> \ 98 inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ 99 typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \ 100 t->mutable_##FIELD##_val()->Swap(©); \ 101 } 102 103 // Complex needs special treatment since proto doesn't have native complex 104 #define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE) \ 105 TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE) \ 106 template <> \ 107 inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ 108 const FTYPE* sub = reinterpret_cast<const FTYPE*>(data); \ 109 typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n); \ 110 t->mutable_##FIELD##_val()->Swap(©); \ 111 } 112 113 TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool); 114 TENSOR_PROTO_EXTRACT_TYPE(float, float, float); 115 TENSOR_PROTO_EXTRACT_TYPE(double, double, double); 116 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float); 117 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double); 118 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); 119 TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32); 120 TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64); 121 TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64); 122 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32); 123 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); 124 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); 125 TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); 126 TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); 127 TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); 128 TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32); 129 130 #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX 131 #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER 132 #undef TENSOR_PROTO_EXTRACT_TYPE 133 134 // Custom implementation for qint32, based on the one for int32. 135 136 template <> 137 struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {}; 138 139 template <> 140 inline const int32* TensorProtoData<qint32>(const TensorProto& t) { 141 static_assert(SaveTypeTraits<qint32>::supported, 142 "Specified type qint32 not supported for Restore"); 143 return reinterpret_cast<const int32*>(t.int_val().data()); 144 } 145 146 inline void Fill(const qint32* data, size_t n, TensorProto* t) { 147 const int32* p = reinterpret_cast<const int32*>(data); 148 typename protobuf::RepeatedField<int32> copy(p, p + n); 149 t->mutable_int_val()->Swap(©); 150 } 151 152 // Custom implementation for Eigen::half. 153 154 template <> 155 struct SaveTypeTraits<Eigen::half> { 156 static constexpr bool supported = true; 157 typedef int SavedType; 158 typedef protobuf::RepeatedField<int32> RepeatedField; 159 }; 160 161 template <> 162 inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) { 163 return t.half_val().data(); 164 } 165 166 template <> 167 inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>( 168 TensorProto* t) { 169 return t->mutable_half_val(); 170 } 171 172 template <> 173 inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) { 174 typename protobuf::RepeatedField<int32>* val = t->mutable_half_val(); 175 val->Resize(n, 0); 176 for (size_t i = 0; i < n; ++i) { 177 val->Set(i, data[i].x); 178 } 179 } 180 181 // Custom implementation for string. 182 183 template <> 184 struct SaveTypeTraits<tstring> { 185 static constexpr bool supported = true; 186 typedef const string* SavedType; 187 typedef protobuf::RepeatedPtrField<string> RepeatedField; 188 }; 189 190 template <> 191 inline const string* const* TensorProtoData<tstring>(const TensorProto& t) { 192 static_assert(SaveTypeTraits<tstring>::supported, 193 "Specified type tstring not supported for Restore"); 194 return t.string_val().data(); 195 } 196 197 template <> 198 inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>( 199 TensorProto* t) { 200 static_assert(SaveTypeTraits<tstring>::supported, 201 "Specified type tstring not supported for Save"); 202 return t->mutable_string_val(); 203 } 204 205 template <> 206 inline void Fill(const tstring* data, size_t n, TensorProto* t) { 207 typename protobuf::RepeatedPtrField<string> copy(data, data + n); 208 t->mutable_string_val()->Swap(©); 209 } 210 211 } // namespace checkpoint 212 213 } // namespace tensorflow 214 215 #endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 216