• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for saving/restoring tensor slice checkpoints.
17 
18 #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
19 #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
20 
21 #include <string>  // for string
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/framework/tensor_slice.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/status.h"  // for Status
26 #include "tensorflow/core/platform/protobuf.h"
27 
28 namespace tensorflow {
29 
30 namespace checkpoint {
31 
32 // The key for the metadata in the tensor slice checkpoint files. It is "" so
33 // that the metadata is always at the beginning of a checkpoint file.
34 extern const char kSavedTensorSlicesKey[];
35 
36 // Encode a tensor name + a tensor slice into an ordered code and outputs it as
37 // a string.
38 // The format is
39 //  <0>
40 //  <tensor_name>
41 //  <rank>
42 //  <dim-0-start><dim-0-length>
43 //  <dim-1-start><dim-1-length>
44 //  ...
45 
46 string EncodeTensorNameSlice(const string& name,
47                              const tensorflow::TensorSlice& slice);
48 
49 // Parse out the name and the slice from string encoded as an ordered code.
50 Status DecodeTensorNameSlice(const string& code, string* name,
51                              tensorflow::TensorSlice* slice);
52 
53 // Extracts the full shape, slice spec, and shape of the slice from
54 // "shape_and_slice".  On non-OK return, caller must clear the out-arguments
55 // before reusing.
56 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
57                           TensorSlice* slice, TensorShape* shape_slice);
58 
59 template <typename T>
60 struct SaveTypeTraits;
61 
62 template <typename T>
63 const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
64     const TensorProto& t);
65 
66 template <typename T>
67 typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData(
68     TensorProto* t);
69 
70 template <typename T>
71 void Fill(T* data, size_t n, TensorProto* t);
72 
// Emits, for an element type TYPE whose values live in the TensorProto
// repeated field FIELD##_val with proto element type FTYPE:
//   * SaveTypeTraits<TYPE>: SavedType = STYPE, RepeatedField =
//     protobuf::RepeatedField<FTYPE>;
//   * TensorProtoData<TYPE>: FIELD##_val's data reinterpreted as const STYPE*;
//   * MutableTensorProtoData<TYPE>: a mutable pointer to FIELD##_val.
// The static_asserts hold trivially (supported is set to true just above);
// they mirror the pattern used by the hand-written specializations.
#define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE)      \
  template <>                                                            \
  struct SaveTypeTraits<TYPE> {                                          \
    static constexpr bool supported = true;                              \
    using SavedType = STYPE;                                             \
    using RepeatedField = protobuf::RepeatedField<FTYPE>;                \
  };                                                                     \
  template <>                                                            \
  inline const STYPE* TensorProtoData<TYPE>(const TensorProto& t) {      \
    static_assert(SaveTypeTraits<TYPE>::supported,                       \
                  "Specified type " #TYPE " not supported for Restore"); \
    const auto* raw = t.FIELD##_val().data();                            \
    return reinterpret_cast<const STYPE*>(raw);                          \
  }                                                                      \
  template <>                                                            \
  inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>(   \
      TensorProto* t) {                                                  \
    static_assert(SaveTypeTraits<TYPE>::supported,                       \
                  "Specified type " #TYPE " not supported for Save");    \
    auto* field = t->mutable_##FIELD##_val();                            \
    return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>(field);     \
  }
94 
// TENSOR_PROTO_EXTRACT_TYPE_HELPER with SavedType == FTYPE, plus a Fill
// specialization that builds a RepeatedField<FTYPE> from the n TYPE values at
// `data` and swaps it into FIELD##_val, so the field ends up holding exactly
// those n values.
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE)            \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE)    \
  template <>                                                    \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
    const TYPE* const end = data + n;                            \
    protobuf::RepeatedField<FTYPE> converted(data, end);         \
    t->mutable_##FIELD##_val()->Swap(&converted);                \
  }
102 
// Complex types need special treatment because TensorProto has no native
// complex field: each TYPE value is stored as two consecutive FTYPE entries
// in FIELD##_val (2 * n entries for n values). SavedType is TYPE itself, so
// TensorProtoData hands the interleaved FTYPE storage back as TYPE values.
// Assumes TYPE is a std::complex<FTYPE>-style type whose storage is two
// FTYPE values -- NOTE(review): confirm for complex64/complex128.
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE)              \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE)               \
  template <>                                                              \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {           \
    const FTYPE* const parts = reinterpret_cast<const FTYPE*>(data);       \
    const size_t part_count = 2 * n;                                       \
    protobuf::RepeatedField<FTYPE> interleaved(parts, parts + part_count); \
    t->mutable_##FIELD##_val()->Swap(&interleaved);                        \
  }
112 
113 TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool);
114 TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
115 TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
116 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float);
117 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double);
118 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
119 TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64);
120 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32);
121 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
122 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
123 TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
124 TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
125 TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
126 TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
127 
128 #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX
129 #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER
130 #undef TENSOR_PROTO_EXTRACT_TYPE
131 
132 // Custom implementation for qint32, based on the one for int32.
133 
134 template <>
135 struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
136 
137 template <>
138 inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
139   static_assert(SaveTypeTraits<qint32>::supported,
140                 "Specified type qint32 not supported for Restore");
141   return reinterpret_cast<const int32*>(t.int_val().data());
142 }
143 
144 inline void Fill(const qint32* data, size_t n, TensorProto* t) {
145   const int32* p = reinterpret_cast<const int32*>(data);
146   typename protobuf::RepeatedField<int32> copy(p, p + n);
147   t->mutable_int_val()->Swap(&copy);
148 }
149 
150 // Custom implementation for Eigen::half.
151 
152 template <>
153 struct SaveTypeTraits<Eigen::half> {
154   static constexpr bool supported = true;
155   typedef int SavedType;
156   typedef protobuf::RepeatedField<int32> RepeatedField;
157 };
158 
159 template <>
160 inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
161   return t.half_val().data();
162 }
163 
164 template <>
165 inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>(
166     TensorProto* t) {
167   return t->mutable_half_val();
168 }
169 
170 template <>
171 inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
172   typename protobuf::RepeatedField<int32>* val = t->mutable_half_val();
173   val->Resize(n, 0);
174   for (size_t i = 0; i < n; ++i) {
175     val->Set(i, data[i].x);
176   }
177 }
178 
179 // Custom implementation for string.
180 
181 template <>
182 struct SaveTypeTraits<string> {
183   static constexpr bool supported = true;
184   typedef const string* SavedType;
185   typedef protobuf::RepeatedPtrField<string> RepeatedField;
186 };
187 
188 template <>
189 inline const string* const* TensorProtoData<string>(const TensorProto& t) {
190   static_assert(SaveTypeTraits<string>::supported,
191                 "Specified type string not supported for Restore");
192   return t.string_val().data();
193 }
194 
195 template <>
196 inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<string>(
197     TensorProto* t) {
198   static_assert(SaveTypeTraits<string>::supported,
199                 "Specified type string not supported for Save");
200   return t->mutable_string_val();
201 }
202 
203 template <>
204 inline void Fill(const string* data, size_t n, TensorProto* t) {
205   typename protobuf::RepeatedPtrField<string> copy(data, data + n);
206   t->mutable_string_val()->Swap(&copy);
207 }
208 
209 }  // namespace checkpoint
210 
211 }  // namespace tensorflow
212 
213 #endif  // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
214