• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for saving/restoring tensor slice checkpoints.
17 
18 #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
19 #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
20 
21 #include <string>  // for string
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/framework/tensor_slice.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/status.h"  // for Status
26 #include "tensorflow/core/platform/protobuf.h"
27 
28 namespace tensorflow {
29 
30 namespace checkpoint {
31 
32 // The key for the metadata in the tensor slice checkpoint files. It is "" so
33 // that the metadata is always at the beginning of a checkpoint file.
34 extern const char kSavedTensorSlicesKey[];
35 
36 // Encode a tensor name + a tensor slice into an ordered code and outputs it as
37 // a string.
38 // The format is
39 //  <0>
40 //  <tensor_name>
41 //  <rank>
42 //  <dim-0-start><dim-0-length>
43 //  <dim-1-start><dim-1-length>
44 //  ...
45 
46 string EncodeTensorNameSlice(const string& name,
47                              const tensorflow::TensorSlice& slice);
48 
49 // Parse out the name and the slice from string encoded as an ordered code.
50 Status DecodeTensorNameSlice(const string& code, string* name,
51                              tensorflow::TensorSlice* slice);
52 
53 // Extracts the full shape, slice spec, and shape of the slice from
54 // "shape_and_slice".  On non-OK return, caller must clear the out-arguments
55 // before reusing.
56 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
57                           TensorSlice* slice, TensorShape* shape_slice);
58 
59 template <typename T>
60 struct SaveTypeTraits;
61 
62 template <typename T>
63 const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
64     const TensorProto& t);
65 
66 template <typename T>
67 typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData(
68     TensorProto* t);
69 
70 template <typename T>
71 void Fill(T* data, size_t n, TensorProto* t);
72 
// Emits, for a single C++ element type, the SaveTypeTraits specialization and
// the TensorProtoData / MutableTensorProtoData specializations that access the
// TensorProto repeated field named `<PROTO_FIELD>_val`.
//   CPP_TYPE    - C++ element type being specialized.
//   PROTO_FIELD - prefix of the TensorProto field (e.g. `float` -> float_val).
//   FIELD_TYPE  - element type of that repeated field.
//   SAVED_TYPE  - element type readers see through TensorProtoData.
// The static_asserts are trivially satisfied, since the specialization
// emitted here sets `supported` to true.
#define TENSOR_PROTO_EXTRACT_TYPE_HELPER(CPP_TYPE, PROTO_FIELD, FIELD_TYPE,   \
                                         SAVED_TYPE)                          \
  template <>                                                                 \
  struct SaveTypeTraits<CPP_TYPE> {                                           \
    static constexpr bool supported = true;                                   \
    using SavedType = SAVED_TYPE;                                             \
    using RepeatedField = protobuf::RepeatedField<FIELD_TYPE>;                \
  };                                                                          \
  template <>                                                                 \
  inline const SAVED_TYPE* TensorProtoData<CPP_TYPE>(const TensorProto& t) {  \
    static_assert(SaveTypeTraits<CPP_TYPE>::supported,                        \
                  "Specified type " #CPP_TYPE " not supported for Restore");  \
    const auto* raw = t.PROTO_FIELD##_val().data();                           \
    return reinterpret_cast<const SAVED_TYPE*>(raw);                          \
  }                                                                           \
  template <>                                                                 \
  inline protobuf::RepeatedField<FIELD_TYPE>*                                 \
  MutableTensorProtoData<CPP_TYPE>(TensorProto* t) {                          \
    static_assert(SaveTypeTraits<CPP_TYPE>::supported,                        \
                  "Specified type " #CPP_TYPE " not supported for Save");     \
    auto* field = t->mutable_##PROTO_FIELD##_val();                           \
    return reinterpret_cast<protobuf::RepeatedField<FIELD_TYPE>*>(field);     \
  }
94 
// Specializes the traits/accessors for a type whose values are saved with the
// proto field's own element type (SavedType == FIELD_TYPE). It also defines
// the matching Fill(), which copies the `n` input values into a fresh
// RepeatedField<FIELD_TYPE> (converting element-wise where the types differ).
// It then swaps that container into `<PROTO_FIELD>_val`, replacing the
// field's previous contents.
#define TENSOR_PROTO_EXTRACT_TYPE(CPP_TYPE, PROTO_FIELD, FIELD_TYPE)          \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(CPP_TYPE, PROTO_FIELD, FIELD_TYPE,         \
                                   FIELD_TYPE)                                \
  template <>                                                                 \
  inline void Fill(const CPP_TYPE* data, size_t n, TensorProto* t) {          \
    protobuf::RepeatedField<FIELD_TYPE> values(data, data + n);               \
    t->mutable_##PROTO_FIELD##_val()->Swap(&values);                          \
  }
102 
// Complex types need special handling because TensorProto has no native
// complex field. Fill() reinterprets the `n` complex values as 2 * n scalar
// FIELD_TYPE components and stores those in `<PROTO_FIELD>_val`. This assumes
// the std::complex-style interleaved (real, imag) layout. Readers get the
// data back reinterpreted as CPP_TYPE (SavedType == CPP_TYPE).
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(CPP_TYPE, PROTO_FIELD, FIELD_TYPE)  \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(CPP_TYPE, PROTO_FIELD, FIELD_TYPE,         \
                                   CPP_TYPE)                                  \
  template <>                                                                 \
  inline void Fill(const CPP_TYPE* data, size_t n, TensorProto* t) {          \
    const FIELD_TYPE* first = reinterpret_cast<const FIELD_TYPE*>(data);      \
    const FIELD_TYPE* last = first + 2 * n;                                   \
    protobuf::RepeatedField<FIELD_TYPE> parts(first, last);                   \
    t->mutable_##PROTO_FIELD##_val()->Swap(&parts);                           \
  }
112 
113 TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool);
114 TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
115 TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
116 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float);
117 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double);
118 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
119 TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32);
120 TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64);
121 TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64);
122 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32);
123 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
124 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
125 TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
126 TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
127 TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
128 TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
129 
130 #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX
131 #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER
132 #undef TENSOR_PROTO_EXTRACT_TYPE
133 
134 // Custom implementation for qint32, based on the one for int32.
135 
136 template <>
137 struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
138 
139 template <>
140 inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
141   static_assert(SaveTypeTraits<qint32>::supported,
142                 "Specified type qint32 not supported for Restore");
143   return reinterpret_cast<const int32*>(t.int_val().data());
144 }
145 
146 inline void Fill(const qint32* data, size_t n, TensorProto* t) {
147   const int32* p = reinterpret_cast<const int32*>(data);
148   typename protobuf::RepeatedField<int32> copy(p, p + n);
149   t->mutable_int_val()->Swap(&copy);
150 }
151 
152 // Custom implementation for Eigen::half.
153 
154 template <>
155 struct SaveTypeTraits<Eigen::half> {
156   static constexpr bool supported = true;
157   typedef int SavedType;
158   typedef protobuf::RepeatedField<int32> RepeatedField;
159 };
160 
161 template <>
162 inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
163   return t.half_val().data();
164 }
165 
166 template <>
167 inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>(
168     TensorProto* t) {
169   return t->mutable_half_val();
170 }
171 
172 template <>
173 inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
174   typename protobuf::RepeatedField<int32>* val = t->mutable_half_val();
175   val->Resize(n, 0);
176   for (size_t i = 0; i < n; ++i) {
177     val->Set(i, data[i].x);
178   }
179 }
180 
181 // Custom implementation for string.
182 
183 template <>
184 struct SaveTypeTraits<tstring> {
185   static constexpr bool supported = true;
186   typedef const string* SavedType;
187   typedef protobuf::RepeatedPtrField<string> RepeatedField;
188 };
189 
190 template <>
191 inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
192   static_assert(SaveTypeTraits<tstring>::supported,
193                 "Specified type tstring not supported for Restore");
194   return t.string_val().data();
195 }
196 
197 template <>
198 inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>(
199     TensorProto* t) {
200   static_assert(SaveTypeTraits<tstring>::supported,
201                 "Specified type tstring not supported for Save");
202   return t->mutable_string_val();
203 }
204 
205 template <>
206 inline void Fill(const tstring* data, size_t n, TensorProto* t) {
207   typename protobuf::RepeatedPtrField<string> copy(data, data + n);
208   t->mutable_string_val()->Swap(&copy);
209 }
210 
211 }  // namespace checkpoint
212 
213 }  // namespace tensorflow
214 
215 #endif  // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
216