• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/variant_op_registry.h"
17 
18 #include <string>
19 
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/type_index.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/public/version.h"
26 
27 namespace tensorflow {
28 
VariantUnaryOpToString(VariantUnaryOp op)29 const char* VariantUnaryOpToString(VariantUnaryOp op) {
30   switch (op) {
31     case INVALID_VARIANT_UNARY_OP:
32       return "INVALID";
33     case ZEROS_LIKE_VARIANT_UNARY_OP:
34       return "ZEROS_LIKE";
35     case CONJ_VARIANT_UNARY_OP:
36       return "CONJ";
37   }
38 }
39 
VariantBinaryOpToString(VariantBinaryOp op)40 const char* VariantBinaryOpToString(VariantBinaryOp op) {
41   switch (op) {
42     case INVALID_VARIANT_BINARY_OP:
43       return "INVALID";
44     case ADD_VARIANT_BINARY_OP:
45       return "ADD";
46   }
47 }
48 
PersistentStringStorage()49 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
50   static std::unordered_set<string>* string_storage =
51       new std::unordered_set<string>();
52   return string_storage;
53 }
54 
GetDecodeFn(StringPiece type_name)55 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
56     StringPiece type_name) {
57   auto found = decode_fns.find(type_name);
58   if (found == decode_fns.end()) return nullptr;
59   return &found->second;
60 }
61 
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)62 void UnaryVariantOpRegistry::RegisterDecodeFn(
63     const string& type_name, const VariantDecodeFn& decode_fn) {
64   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
65   VariantDecodeFn* existing = GetDecodeFn(type_name);
66   CHECK_EQ(existing, nullptr)
67       << "Unary VariantDecodeFn for type_name: " << type_name
68       << " already registered";
69   decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
70       GetPersistentStringPiece(type_name), decode_fn));
71 }
72 
DecodeUnaryVariant(Variant * variant)73 bool DecodeUnaryVariant(Variant* variant) {
74   CHECK_NOTNULL(variant);
75   if (variant->TypeName().empty()) {
76     VariantTensorDataProto* t = variant->get<VariantTensorDataProto>();
77     if (t == nullptr || !t->metadata().empty() || !t->tensors().empty()) {
78       // Malformed variant.
79       return false;
80     } else {
81       // Serialization of an empty Variant.
82       variant->clear();
83       return true;
84     }
85   }
86   UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
87       UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
88   if (decode_fn == nullptr) {
89     return false;
90   }
91   const string type_name = variant->TypeName();
92   bool decoded = (*decode_fn)(variant);
93   if (!decoded) return false;
94   if (variant->TypeName() != type_name) {
95     LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
96                << type_name
97                << " but after decoding was: " << variant->TypeName()
98                << ".  Treating this as a failure.";
99     return false;
100   }
101   return true;
102 }
103 
104 // Add some basic registrations for use by others, e.g., for testing.
105 
106 #define REGISTER_VARIANT_DECODE_TYPE(T) \
107   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
108 
109 // No encode/decode registered for std::complex<> and Eigen::half
110 // objects yet.
111 REGISTER_VARIANT_DECODE_TYPE(int);
112 REGISTER_VARIANT_DECODE_TYPE(float);
113 REGISTER_VARIANT_DECODE_TYPE(bool);
114 REGISTER_VARIANT_DECODE_TYPE(double);
115 
116 #undef REGISTER_VARIANT_DECODE_TYPE
117 
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)118 Status VariantDeviceCopy(
119     const VariantDeviceCopyDirection direction, const Variant& from,
120     Variant* to,
121     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
122   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
123       UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
124                                                         from.TypeId());
125   if (device_copy_fn == nullptr) {
126     return errors::Internal(
127         "No unary variant device copy function found for direction: ",
128         direction, " and Variant type_index: ",
129         port::MaybeAbiDemangle(from.TypeId().name()));
130   }
131   return (*device_copy_fn)(from, to, copy_fn);
132 }
133 
134 namespace {
135 template <typename T>
DeviceCopyPrimitiveType(const T & in,T * out,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copier)136 Status DeviceCopyPrimitiveType(
137     const T& in, T* out,
138     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
139   // Dummy copy, we don't actually bother copying to the device and back for
140   // testing.
141   *out = in;
142   return Status::OK();
143 }
144 }  // namespace
145 
146 #define REGISTER_VARIANT_DEVICE_COPY_TYPE(T)            \
147   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
148       T, VariantDeviceCopyDirection::HOST_TO_DEVICE,    \
149       DeviceCopyPrimitiveType<T>);                      \
150   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
151       T, VariantDeviceCopyDirection::DEVICE_TO_HOST,    \
152       DeviceCopyPrimitiveType<T>);                      \
153   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
154       T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,  \
155       DeviceCopyPrimitiveType<T>);
156 
157 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
158 REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
159 REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
160 REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
161 REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);
162 
163 #undef REGISTER_VARIANT_DEVICE_COPY_TYPE
164 
165 namespace {
166 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)167 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
168                                      T* t_out) {
169   *t_out = T(0);
170   return Status::OK();
171 }
172 }  // namespace
173 
174 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
175   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
176                                            DEVICE_CPU, T,               \
177                                            ZerosLikeVariantPrimitiveType<T>);
178 
179 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
180 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
181 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
182 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
183 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
184 
185 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
186 
187 namespace {
188 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)189 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
190                                T* out) {
191   *out = a + b;
192   return Status::OK();
193 }
194 }  // namespace
195 
196 #define REGISTER_VARIANT_ADD_TYPE(T)                                           \
197   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
198                                             T, AddVariantPrimitiveType<T>);
199 
200 // No add registered for std::complex<> or Eigen::half objects yet.
201 REGISTER_VARIANT_ADD_TYPE(int);
202 REGISTER_VARIANT_ADD_TYPE(float);
203 REGISTER_VARIANT_ADD_TYPE(double);
204 REGISTER_VARIANT_ADD_TYPE(bool);
205 
206 #undef REGISTER_VARIANT_ADD_TYPE
207 
208 }  // namespace tensorflow
209