• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
18 
19 #include <iostream>
20 #include <type_traits>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/type_index.h"
26 #include "tensorflow/core/framework/variant_tensor_data.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/abi.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 
31 namespace tensorflow {
32 
33 // Type used for tag-dispatch of the Encode/Decode Variant implementations. This
34 // template can determine whether the first type parameter `T` is one of the
35 // following:
36 //
37 // * A POD type (TypeResolver<T, true>)
38 // * A tensorflow::Tensor (TypeResolver<T, false, true>)
39 // * A protocol buffer (TypeResolver<T, false, false, true>)
40 // * None of the above (TypeResolver<T, false, false, false>)
41 //
42 template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value,
43           bool = std::is_same<typename std::decay<T>::type,
44                               ::tensorflow::Tensor>::value,
45           bool = std::is_base_of<protobuf::MessageLite,
46                                  typename std::decay<T>::type>::value>
47 struct TypeResolver {};
48 
49 // Specialization for POD type
50 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,true>,VariantTensorData * data)51 void EncodeVariantImpl(const T& value, TypeResolver<T, true /* is_pod */>,
52                        VariantTensorData* data) {
53   data->set_metadata(value);
54 }
55 
56 // Specialization for tensorflow::Tensor
57 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,false,true>,VariantTensorData * data)58 void EncodeVariantImpl(const T& value,
59                        TypeResolver<T, false /* is_pod */, true /* Tensor */>,
60                        VariantTensorData* data) {
61   data->tensors_.clear();
62   data->tensors_.push_back(value);
63 }
64 
65 // Specialization for protobuf
66 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,false,false,true>,VariantTensorData * data)67 void EncodeVariantImpl(const T& value,
68                        TypeResolver<T, false /* is_pod */, false /* Tensor */,
69                                     true /* protobuf */>,
70                        VariantTensorData* data) {
71   value.SerializeToString(&data->metadata_);
72 }
73 
74 // Specialization for other types
75 template <typename T>
EncodeVariantImpl(const T & value,TypeResolver<T,false,false,false>,VariantTensorData * data)76 void EncodeVariantImpl(const T& value,
77                        TypeResolver<T, false /* is_pod */, false /* Tensor */,
78                                     false /* protobuf */>,
79                        VariantTensorData* data) {
80   value.Encode(data);
81 }
82 
83 // Specialization for POD type
84 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,true,false,false>,T * value)85 bool DecodeVariantImpl(VariantTensorData data,
86                        TypeResolver<T, true /* is_pod */, false /* Tensor */,
87                                     false /* protobuf */>,
88                        T* value) {
89   return data.get_metadata(value);
90 }
91 
92 // Specialization for tensorflow::Tensor
93 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,false,true,false>,T * value)94 bool DecodeVariantImpl(VariantTensorData data,
95                        TypeResolver<T, false /* is_pod */, true /* Tensor */,
96                                     false /* protobuf */>,
97                        T* value) {
98   *value = data.tensors(0);
99   return true;
100 }
101 
102 // Specialization for protobuf
103 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,false,false,true>,T * value)104 bool DecodeVariantImpl(VariantTensorData data,
105                        TypeResolver<T, false /* is_pod */, false /* Tensor */,
106                                     true /* protobuf */>,
107                        T* value) {
108   string metadata;
109   data.get_metadata(&metadata);
110   return value->ParseFromString(std::move(metadata));
111 }
112 
113 // Specialization for other types
114 template <typename T>
DecodeVariantImpl(VariantTensorData data,TypeResolver<T,false,false,false>,T * value)115 bool DecodeVariantImpl(VariantTensorData data,
116                        TypeResolver<T, false /* is_pod */, false /* Tensor */,
117                                     false /* protobuf */>,
118                        T* value) {
119   return value->Decode(std::move(data));
120 }
121 
// Trait: true iff `C` has a TypeName() member, callable on an rvalue `C`,
// whose return type is exactly `string`. Otherwise false.
template <typename C, typename = void>
struct has_type_name : std::integral_constant<bool, false> {};

template <typename C>
struct has_type_name<
    C, typename std::enable_if<std::is_same<
           string, decltype(std::declval<C>().TypeName())>::value>::type>
    : std::integral_constant<bool, true> {};
130 
131 template <typename T, bool = has_type_name<typename std::decay<T>::type>::value,
132           bool = std::is_same<typename std::decay<T>::type,
133                               ::tensorflow::Tensor>::value,
134           bool = std::is_base_of<protobuf::MessageLite,
135                                  typename std::decay<T>::type>::value>
136 struct TypeNameResolver {};
137 
138 template <typename T>
139 string TypeNameVariantImpl(const T& value,
140                            TypeNameResolver<T, true /* has_type_name */>) {
141   return value.TypeName();
142 }
143 
144 template <typename T>
145 string TypeNameVariantImpl(
146     const T& value,
147     TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
148   return "tensorflow::Tensor";
149 }
150 
151 template <typename T>
152 string TypeNameVariantImpl(
153     const T& value, TypeNameResolver<T, false /* has_type_name */,
154                                      false /* Tensor */, true /* protobuf */>) {
155   return value.GetTypeName();
156 }
157 
158 template <typename T>
159 string TypeNameVariantImpl(
160     const T& value,
161     TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
162                      false /* protobuf */>) {
163   return port::MaybeAbiDemangle(MakeTypeIndex<T>().name());
164 }
165 
166 template <typename T>
167 string TypeNameVariant(const T& value) {
168   return TypeNameVariantImpl(value, TypeNameResolver<T>());
169 }
170 
// Trait: true iff `C` has a DebugString() member, callable on an rvalue `C`,
// whose return type is exactly `string`. Otherwise false.
template <typename C, typename = void>
struct has_debug_string : std::integral_constant<bool, false> {};

template <typename C>
struct has_debug_string<
    C, typename std::enable_if<std::is_same<
           string, decltype(std::declval<C>().DebugString())>::value>::type>
    : std::integral_constant<bool, true> {};
179 
180 template <typename C, typename = void>
181 struct can_strcat : std::false_type {};
182 
183 template <typename C>
184 struct can_strcat<
185     C, typename std::enable_if<std::is_same<
186            decltype(strings::StrCat(std::declval<C>())), string>::value>::type>
187     : std::true_type {};
188 
189 template <typename T,
190           bool = has_debug_string<typename std::decay<T>::type>::value,
191           bool = can_strcat<typename std::decay<T>::type>::value>
192 struct DebugStringResolver {};
193 
194 // TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
195 // there is no StrCat<T>() constructor.
196 template <typename T>
197 string DebugStringVariantImpl(
198     const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
199   return value.DebugString();
200 }
201 
202 template <typename T>
203 string DebugStringVariantImpl(
204     const T& value, DebugStringResolver<T, false /* has_debug_string */,
205                                         true /* can_strcat */>) {
206   return strings::StrCat(value);
207 }
208 
209 template <typename T>
210 string DebugStringVariantImpl(
211     const T& value, DebugStringResolver<T, false /* has_debug_string */,
212                                         false /* can_strcat */>) {
213   return "?";
214 }
215 
216 template <typename T>
217 string DebugStringVariant(const T& value) {
218   return DebugStringVariantImpl(value, DebugStringResolver<T>());
219 }
220 
221 template <typename T>
222 void EncodeVariant(const T& value, VariantTensorData* data) {
223   EncodeVariantImpl(value, TypeResolver<T>(), data);
224   data->set_type_name(TypeNameVariant(value));
225 }
226 
227 template <typename T>
228 bool DecodeVariant(VariantTensorData* data, T* value) {
229   return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
230 }
231 
232 template <typename T>
233 void EncodeVariant(const T& value, string* buf) {
234   VariantTensorData data;
235   EncodeVariantImpl(value, TypeResolver<T>(), &data);
236   data.set_type_name(TypeNameVariant(value));
237   DCHECK(buf != nullptr);
238   data.SerializeToString(buf);
239 }
240 
241 template <typename T>
242 bool DecodeVariant(string* buf, T* value) {
243   VariantTensorData data;
244   if (!data.ParseFromString(*buf)) return false;
245   if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
246     return false;
247   }
248   return true;
249 }
250 
251 // Specializations for VariantTensorDataProto
252 template <>
253 string TypeNameVariant(const VariantTensorDataProto& value);
254 
255 template <>
256 void EncodeVariant(const VariantTensorDataProto& value,
257                    VariantTensorData* data);
258 
259 template <>
260 bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
261 
262 template <>
263 void EncodeVariant(const VariantTensorDataProto& value, string* buf);
264 
265 template <>
266 bool DecodeVariant(string* buf, VariantTensorDataProto* value);
267 
268 // Encodes an array of Variant objects in to the given StringListEncoder.
269 // `variant_array` is assumed to point to an array of `n` Variant objects.
270 void EncodeVariantList(const Variant* variant_array, int64 n,
271                        std::unique_ptr<port::StringListEncoder> e);
272 
273 // Decodes an array of Variant objects from the given StringListDecoder.
274 // `variant_array` is assumed to point to an array of `n` Variant objects.
275 bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
276                        Variant* variant_array, int64 n);
277 
278 }  // end namespace tensorflow
279 
280 #endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
281