1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
18
19 #include <iostream>
20 #include <type_traits>
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/type_index.h"
26 #include "tensorflow/core/framework/variant_tensor_data.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/abi.h"
29 #include "tensorflow/core/platform/protobuf.h"
30
31 namespace tensorflow {
32
33 // Type used for tag-dispatch of the Encode/Decode Variant implementations. This
34 // template can determine whether the first type parameter `T` is one of the
35 // following:
36 //
37 // * A POD type (TypeResolver<T, true>)
38 // * A tensorflow::Tensor (TypeResolver<T, false, true>)
39 // * A protocol buffer (TypeResolver<T, false, false, true>)
40 // * None of the above (TypeResolver<T, false, false, false>)
41 //
42 template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value,
43 bool = std::is_same<typename std::decay<T>::type,
44 ::tensorflow::Tensor>::value,
45 bool = std::is_base_of<protobuf::MessageLite,
46 typename std::decay<T>::type>::value>
47 struct TypeResolver {};
48
49 // Specialization for POD type
50 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,true>,VariantTensorData * data)51 void EncodeVariantImpl(const T& value, TypeResolver<T, true /* is_pod */>,
52 VariantTensorData* data) {
53 data->set_metadata(value);
54 }
55
56 // Specialization for tensorflow::Tensor
57 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,false,true>,VariantTensorData * data)58 void EncodeVariantImpl(const T& value,
59 TypeResolver<T, false /* is_pod */, true /* Tensor */>,
60 VariantTensorData* data) {
61 data->tensors_.clear();
62 data->tensors_.push_back(value);
63 }
64
65 // Specialization for protobuf
66 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,false,false,true>,VariantTensorData * data)67 void EncodeVariantImpl(const T& value,
68 TypeResolver<T, false /* is_pod */, false /* Tensor */,
69 true /* protobuf */>,
70 VariantTensorData* data) {
71 value.SerializeToString(&data->metadata_);
72 }
73
74 // Specialization for other types
75 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,false,false,false>,VariantTensorData * data)76 void EncodeVariantImpl(const T& value,
77 TypeResolver<T, false /* is_pod */, false /* Tensor */,
78 false /* protobuf */>,
79 VariantTensorData* data) {
80 value.Encode(data);
81 }
82
83 // Specialization for POD type
84 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,true,false,false>,T * value)85 bool DecodeVariantImpl(VariantTensorData data,
86 TypeResolver<T, true /* is_pod */, false /* Tensor */,
87 false /* protobuf */>,
88 T* value) {
89 return data.get_metadata(value);
90 }
91
92 // Specialization for tensorflow::Tensor
93 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,false,true,false>,T * value)94 bool DecodeVariantImpl(VariantTensorData data,
95 TypeResolver<T, false /* is_pod */, true /* Tensor */,
96 false /* protobuf */>,
97 T* value) {
98 *value = data.tensors(0);
99 return true;
100 }
101
102 // Specialization for protobuf
103 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,false,false,true>,T * value)104 bool DecodeVariantImpl(VariantTensorData data,
105 TypeResolver<T, false /* is_pod */, false /* Tensor */,
106 true /* protobuf */>,
107 T* value) {
108 std::string metadata;
109 data.get_metadata(&metadata);
110 return value->ParseFromString(std::move(metadata));
111 }
112
113 // Specialization for other types
114 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,false,false,false>,T * value)115 bool DecodeVariantImpl(VariantTensorData data,
116 TypeResolver<T, false /* is_pod */, false /* Tensor */,
117 false /* protobuf */>,
118 T* value) {
119 return value->Decode(std::move(data));
120 }
121
122 template <typename C, typename = void>
123 struct has_type_name : std::false_type {};
124
125 template <typename C>
126 struct has_type_name<
127 C, typename std::enable_if<std::is_same<
128 decltype(std::declval<C>().TypeName()), string>::value>::type>
129 : std::true_type {};
130
131 template <typename T, bool = has_type_name<typename std::decay<T>::type>::value,
132 bool = std::is_same<typename std::decay<T>::type,
133 ::tensorflow::Tensor>::value,
134 bool = std::is_base_of<protobuf::MessageLite,
135 typename std::decay<T>::type>::value>
136 struct TypeNameResolver {};
137
138 template <typename T>
139 std::string TypeNameVariantImpl(const T& value,
140 TypeNameResolver<T, true /* has_type_name */>) {
141 return value.TypeName();
142 }
143
144 template <typename T>
145 std::string TypeNameVariantImpl(
146 const T& value,
147 TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
148 return "tensorflow::Tensor";
149 }
150
151 template <typename T>
152 std::string TypeNameVariantImpl(
153 const T& value, TypeNameResolver<T, false /* has_type_name */,
154 false /* Tensor */, true /* protobuf */>) {
155 return value.GetTypeName();
156 }
157
158 template <typename T>
159 std::string TypeNameVariantImpl(
160 const T& value,
161 TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
162 false /* protobuf */>) {
163 return port::MaybeAbiDemangle(TypeIndex::Make<T>().name());
164 }
165
166 template <typename T>
167 std::string TypeNameVariant(const T& value) {
168 return TypeNameVariantImpl(value, TypeNameResolver<T>());
169 }
170
171 template <typename C, typename = void>
172 struct has_debug_string : std::false_type {};
173
174 template <typename C>
175 struct has_debug_string<
176 C, typename std::enable_if<std::is_same<
177 decltype(std::declval<C>().DebugString()), string>::value>::type>
178 : std::true_type {};
179
180 template <typename C, typename = void>
181 struct can_strcat : std::false_type {};
182
183 template <typename C>
184 struct can_strcat<
185 C, typename std::enable_if<std::is_same<
186 decltype(strings::StrCat(std::declval<C>())), string>::value>::type>
187 : std::true_type {};
188
189 template <typename T,
190 bool = has_debug_string<typename std::decay<T>::type>::value,
191 bool = can_strcat<typename std::decay<T>::type>::value>
192 struct DebugStringResolver {};
193
194 // TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
195 // there is no StrCat<T>() constructor.
196 template <typename T>
197 std::string DebugStringVariantImpl(
198 const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
199 return value.DebugString();
200 }
201
202 template <typename T>
203 std::string DebugStringVariantImpl(
204 const T& value, DebugStringResolver<T, false /* has_debug_string */,
205 true /* can_strcat */>) {
206 return strings::StrCat(value);
207 }
208
209 template <typename T>
210 std::string DebugStringVariantImpl(
211 const T& value, DebugStringResolver<T, false /* has_debug_string */,
212 false /* can_strcat */>) {
213 return "?";
214 }
215
216 template <typename T>
217 std::string DebugStringVariant(const T& value) {
218 return DebugStringVariantImpl(value, DebugStringResolver<T>());
219 }
220
221 template <typename T>
222 void EncodeVariant(const T& value, VariantTensorData* data) {
223 EncodeVariantImpl(value, TypeResolver<T>(), data);
224 data->set_type_name(TypeNameVariant(value));
225 }
226
227 template <typename T>
228 bool DecodeVariant(VariantTensorData* data, T* value) {
229 return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
230 }
231
232 template <typename T>
233 void EncodeVariant(const T& value, std::string* buf) {
234 VariantTensorData data;
235 EncodeVariantImpl(value, TypeResolver<T>(), &data);
236 data.set_type_name(TypeNameVariant(value));
237 DCHECK(buf != nullptr);
238 data.SerializeToString(buf);
239 }
240
241 template <typename T>
242 bool DecodeVariant(std::string* buf, T* value) {
243 VariantTensorData data;
244 if (!data.ParseFromString(*buf)) return false;
245 if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
246 return false;
247 }
248 return true;
249 }
250
251 // Specializations for VariantTensorDataProto
252 template <>
253 std::string TypeNameVariant(const VariantTensorDataProto& value);
254
255 template <>
256 void EncodeVariant(const VariantTensorDataProto& value,
257 VariantTensorData* data);
258
259 template <>
260 bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
261
262 template <>
263 void EncodeVariant(const VariantTensorDataProto& value, std::string* buf);
264
265 template <>
266 bool DecodeVariant(std::string* buf, VariantTensorDataProto* value);
267
268 // Encodes an array of Variant objects in to the given StringListEncoder.
269 // `variant_array` is assumed to point to an array of `n` Variant objects.
270 void EncodeVariantList(const Variant* variant_array, int64 n,
271 std::unique_ptr<port::StringListEncoder> e);
272
273 // Decodes an array of Variant objects from the given StringListDecoder.
274 // `variant_array` is assumed to point to an array of `n` Variant objects.
275 bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
276 Variant* variant_array, int64 n);
277
278 } // end namespace tensorflow
279
280 #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
281