1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/type_index.h"
20 #include "tensorflow/core/framework/variant.h"
21 #include "tensorflow/core/framework/variant_op_registry.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/public/version.h"
24
25 namespace tensorflow {
26
PersistentStringStorage()27 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
28 static std::unordered_set<string>* string_storage =
29 new std::unordered_set<string>();
30 return string_storage;
31 }
32
33 // static
Global()34 UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
35 static UnaryVariantOpRegistry* global_unary_variant_op_registry =
36 new UnaryVariantOpRegistry;
37 return global_unary_variant_op_registry;
38 }
39
GetDecodeFn(StringPiece type_name)40 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
41 StringPiece type_name) {
42 auto found = decode_fns.find(type_name);
43 if (found == decode_fns.end()) return nullptr;
44 return &found->second;
45 }
46
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)47 void UnaryVariantOpRegistry::RegisterDecodeFn(
48 const string& type_name, const VariantDecodeFn& decode_fn) {
49 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
50 VariantDecodeFn* existing = GetDecodeFn(type_name);
51 CHECK_EQ(existing, nullptr)
52 << "Unary VariantDecodeFn for type_name: " << type_name
53 << " already registered";
54 decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
55 GetPersistentStringPiece(type_name), decode_fn));
56 }
57
DecodeUnaryVariant(Variant * variant)58 bool DecodeUnaryVariant(Variant* variant) {
59 UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
60 UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
61 if (decode_fn == nullptr) {
62 return false;
63 }
64 const string type_name = variant->TypeName();
65 bool decoded = (*decode_fn)(variant);
66 if (!decoded) return false;
67 if (variant->TypeName() != type_name) {
68 LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
69 << type_name
70 << " but after decoding was: " << variant->TypeName()
71 << ". Treating this as a failure.";
72 return false;
73 }
74 return true;
75 }
76
77 // Add some basic registrations for use by others, e.g., for testing.
78
79 #define REGISTER_VARIANT_DECODE_TYPE(T) \
80 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
81
82 // No encode/decode registered for std::complex<> and Eigen::half
83 // objects yet.
84 REGISTER_VARIANT_DECODE_TYPE(int);
85 REGISTER_VARIANT_DECODE_TYPE(float);
86 REGISTER_VARIANT_DECODE_TYPE(bool);
87 REGISTER_VARIANT_DECODE_TYPE(double);
88
89 #undef REGISTER_VARIANT_DECODE_TYPE
90
91 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
GetDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index)92 UnaryVariantOpRegistry::GetDeviceCopyFn(
93 const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
94 auto found = device_copy_fns.find(std::make_pair(direction, type_index));
95 if (found == device_copy_fns.end()) return nullptr;
96 return &found->second;
97 }
98
RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const AsyncVariantDeviceCopyFn & device_copy_fn)99 void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
100 const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
101 const AsyncVariantDeviceCopyFn& device_copy_fn) {
102 AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
103 CHECK_EQ(existing, nullptr)
104 << "UnaryVariantDeviceCopy for direction: " << direction
105 << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
106 << " already registered";
107 device_copy_fns.insert(
108 std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
109 AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
110 device_copy_fn));
111 }
112
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)113 Status VariantDeviceCopy(
114 const VariantDeviceCopyDirection direction, const Variant& from,
115 Variant* to,
116 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
117 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
118 UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
119 from.TypeId());
120 if (device_copy_fn == nullptr) {
121 return errors::Internal(
122 "No unary variant device copy function found for direction: ",
123 direction, " and Variant type_index: ",
124 port::MaybeAbiDemangle(from.TypeId().name()));
125 }
126 return (*device_copy_fn)(from, to, copy_fn);
127 }
128
129 namespace {
130 template <typename T>
DeviceCopyPrimitiveType(const T & in,T * out,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copier)131 Status DeviceCopyPrimitiveType(
132 const T& in, T* out,
133 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
134 // Dummy copy, we don't actually bother copying to the device and back for
135 // testing.
136 *out = in;
137 return Status::OK();
138 }
139 } // namespace
140
141 #define REGISTER_VARIANT_DEVICE_COPY_TYPE(T) \
142 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
143 T, VariantDeviceCopyDirection::HOST_TO_DEVICE, \
144 DeviceCopyPrimitiveType<T>); \
145 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
146 T, VariantDeviceCopyDirection::DEVICE_TO_HOST, \
147 DeviceCopyPrimitiveType<T>); \
148 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
149 T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, \
150 DeviceCopyPrimitiveType<T>);
151
152 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
153 REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
154 REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
155 REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
156 REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);
157
158 #undef REGISTER_VARIANT_DEVICE_COPY_TYPE
159
160 // Special casing UnaryOpFn per op and per device.
GetUnaryOpFn(VariantUnaryOp op,StringPiece device,const TypeIndex & type_index)161 UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
162 VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
163 auto found = unary_op_fns.find({op, device, type_index});
164 if (found == unary_op_fns.end()) return nullptr;
165 return &found->second;
166 }
167
RegisterUnaryOpFn(VariantUnaryOp op,const string & device,const TypeIndex & type_index,const VariantUnaryOpFn & unary_op_fn)168 void UnaryVariantOpRegistry::RegisterUnaryOpFn(
169 VariantUnaryOp op, const string& device, const TypeIndex& type_index,
170 const VariantUnaryOpFn& unary_op_fn) {
171 VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
172 CHECK_EQ(existing, nullptr)
173 << "Unary VariantUnaryOpFn for type_index: "
174 << port::MaybeAbiDemangle(type_index.name())
175 << " already registered for device type: " << device;
176 unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
177 {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
178 }
179
180 namespace {
181 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)182 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
183 T* t_out) {
184 *t_out = T(0);
185 return Status::OK();
186 }
187 } // namespace
188
189 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
190 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
191 DEVICE_CPU, T, \
192 ZerosLikeVariantPrimitiveType<T>);
193
194 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
195 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
196 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
197 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
198 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
199
200 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
201
202 // Special casing BinaryOpFn per op and per device.
203 UnaryVariantOpRegistry::VariantBinaryOpFn*
GetBinaryOpFn(VariantBinaryOp op,StringPiece device,const TypeIndex & type_index)204 UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
205 const TypeIndex& type_index) {
206 auto found = binary_op_fns.find({op, device, type_index});
207 if (found == binary_op_fns.end()) return nullptr;
208 return &found->second;
209 }
210
RegisterBinaryOpFn(VariantBinaryOp op,const string & device,const TypeIndex & type_index,const VariantBinaryOpFn & add_fn)211 void UnaryVariantOpRegistry::RegisterBinaryOpFn(
212 VariantBinaryOp op, const string& device, const TypeIndex& type_index,
213 const VariantBinaryOpFn& add_fn) {
214 VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
215 CHECK_EQ(existing, nullptr)
216 << "Unary VariantBinaryOpFn for type_index: "
217 << port::MaybeAbiDemangle(type_index.name())
218 << " already registered for device type: " << device;
219 binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
220 {op, GetPersistentStringPiece(device), type_index}, add_fn));
221 }
222
223 namespace {
224 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)225 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
226 T* out) {
227 *out = a + b;
228 return Status::OK();
229 }
230 } // namespace
231
232 #define REGISTER_VARIANT_ADD_TYPE(T) \
233 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
234 T, AddVariantPrimitiveType<T>);
235
236 // No add registered for std::complex<> or Eigen::half objects yet.
237 REGISTER_VARIANT_ADD_TYPE(int);
238 REGISTER_VARIANT_ADD_TYPE(float);
239 REGISTER_VARIANT_ADD_TYPE(double);
240 REGISTER_VARIANT_ADD_TYPE(bool);
241
242 #undef REGISTER_VARIANT_ADD_TYPE
243
244 } // namespace tensorflow
245