1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/variant_op_registry.h"
17
18 #include <string>
19
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/type_index.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/public/version.h"
26
27 namespace tensorflow {
28
VariantUnaryOpToString(VariantUnaryOp op)29 const char* VariantUnaryOpToString(VariantUnaryOp op) {
30 switch (op) {
31 case INVALID_VARIANT_UNARY_OP:
32 return "INVALID";
33 case ZEROS_LIKE_VARIANT_UNARY_OP:
34 return "ZEROS_LIKE";
35 case CONJ_VARIANT_UNARY_OP:
36 return "CONJ";
37 }
38 }
39
VariantBinaryOpToString(VariantBinaryOp op)40 const char* VariantBinaryOpToString(VariantBinaryOp op) {
41 switch (op) {
42 case INVALID_VARIANT_BINARY_OP:
43 return "INVALID";
44 case ADD_VARIANT_BINARY_OP:
45 return "ADD";
46 }
47 }
48
PersistentStringStorage()49 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
50 static std::unordered_set<string>* string_storage =
51 new std::unordered_set<string>();
52 return string_storage;
53 }
54
GetDecodeFn(StringPiece type_name)55 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
56 StringPiece type_name) {
57 auto found = decode_fns.find(type_name);
58 if (found == decode_fns.end()) return nullptr;
59 return &found->second;
60 }
61
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)62 void UnaryVariantOpRegistry::RegisterDecodeFn(
63 const string& type_name, const VariantDecodeFn& decode_fn) {
64 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
65 VariantDecodeFn* existing = GetDecodeFn(type_name);
66 CHECK_EQ(existing, nullptr)
67 << "Unary VariantDecodeFn for type_name: " << type_name
68 << " already registered";
69 decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
70 GetPersistentStringPiece(type_name), decode_fn));
71 }
72
DecodeUnaryVariant(Variant * variant)73 bool DecodeUnaryVariant(Variant* variant) {
74 CHECK_NOTNULL(variant);
75 if (variant->TypeName().empty()) {
76 VariantTensorDataProto* t = variant->get<VariantTensorDataProto>();
77 if (t == nullptr || !t->metadata().empty() || !t->tensors().empty()) {
78 // Malformed variant.
79 return false;
80 } else {
81 // Serialization of an empty Variant.
82 variant->clear();
83 return true;
84 }
85 }
86 UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
87 UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
88 if (decode_fn == nullptr) {
89 return false;
90 }
91 const string type_name = variant->TypeName();
92 bool decoded = (*decode_fn)(variant);
93 if (!decoded) return false;
94 if (variant->TypeName() != type_name) {
95 LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
96 << type_name
97 << " but after decoding was: " << variant->TypeName()
98 << ". Treating this as a failure.";
99 return false;
100 }
101 return true;
102 }
103
104 // Add some basic registrations for use by others, e.g., for testing.
105
106 #define REGISTER_VARIANT_DECODE_TYPE(T) \
107 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
108
109 // No encode/decode registered for std::complex<> and Eigen::half
110 // objects yet.
111 REGISTER_VARIANT_DECODE_TYPE(int);
112 REGISTER_VARIANT_DECODE_TYPE(float);
113 REGISTER_VARIANT_DECODE_TYPE(bool);
114 REGISTER_VARIANT_DECODE_TYPE(double);
115
116 #undef REGISTER_VARIANT_DECODE_TYPE
117
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)118 Status VariantDeviceCopy(
119 const VariantDeviceCopyDirection direction, const Variant& from,
120 Variant* to,
121 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
122 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
123 UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
124 from.TypeId());
125 if (device_copy_fn == nullptr) {
126 return errors::Internal(
127 "No unary variant device copy function found for direction: ",
128 direction, " and Variant type_index: ",
129 port::MaybeAbiDemangle(from.TypeId().name()));
130 }
131 return (*device_copy_fn)(from, to, copy_fn);
132 }
133
134 namespace {
135 template <typename T>
DeviceCopyPrimitiveType(const T & in,T * out,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copier)136 Status DeviceCopyPrimitiveType(
137 const T& in, T* out,
138 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
139 // Dummy copy, we don't actually bother copying to the device and back for
140 // testing.
141 *out = in;
142 return Status::OK();
143 }
144 } // namespace
145
146 #define REGISTER_VARIANT_DEVICE_COPY_TYPE(T) \
147 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
148 T, VariantDeviceCopyDirection::HOST_TO_DEVICE, \
149 DeviceCopyPrimitiveType<T>); \
150 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
151 T, VariantDeviceCopyDirection::DEVICE_TO_HOST, \
152 DeviceCopyPrimitiveType<T>); \
153 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
154 T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, \
155 DeviceCopyPrimitiveType<T>);
156
157 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
158 REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
159 REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
160 REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
161 REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);
162
163 #undef REGISTER_VARIANT_DEVICE_COPY_TYPE
164
165 namespace {
166 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)167 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
168 T* t_out) {
169 *t_out = T(0);
170 return Status::OK();
171 }
172 } // namespace
173
174 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
175 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
176 DEVICE_CPU, T, \
177 ZerosLikeVariantPrimitiveType<T>);
178
179 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
180 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
181 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
182 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
183 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
184
185 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
186
187 namespace {
188 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)189 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
190 T* out) {
191 *out = a + b;
192 return Status::OK();
193 }
194 } // namespace
195
196 #define REGISTER_VARIANT_ADD_TYPE(T) \
197 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
198 T, AddVariantPrimitiveType<T>);
199
200 // No add registered for std::complex<> or Eigen::half objects yet.
201 REGISTER_VARIANT_ADD_TYPE(int);
202 REGISTER_VARIANT_ADD_TYPE(float);
203 REGISTER_VARIANT_ADD_TYPE(double);
204 REGISTER_VARIANT_ADD_TYPE(bool);
205
206 #undef REGISTER_VARIANT_ADD_TYPE
207
208 } // namespace tensorflow
209