/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_

#include <string>
#include <unordered_set>
#include <vector>

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/abi.h"

namespace tensorflow {

class OpKernelContext;
// A global UnaryVariantOpRegistry is used to hold callback functions
// for different variant types. It is used by ops such as ShapeOp, RankOp,
// and SizeOp, as well as by variant decoding, device copies, and
// elementwise unary/binary ops on variants.

enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

const char* VariantUnaryOpToString(VariantUnaryOp op);

enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

const char* VariantBinaryOpToString(VariantBinaryOp op);

enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};

class UnaryVariantOpRegistry {
 public:
  typedef std::function<bool(Variant*)> VariantDecodeFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
      VariantUnaryOpFn;
  typedef std::function<Status(OpKernelContext*, const Variant&,
                               const Variant&, Variant*)>
      VariantBinaryOpFn;

  // An AsyncTensorDeviceCopyFn is a function provided to
  // the user-provided DeviceCopyFn callback as the third argument ("copier").
  //
  // Expected inputs:
  //   from: A Tensor on the host (if performing cpu->gpu copy), or
  //         device (if performing gpu->cpu or gpu->gpu copy).
  //   to: An empty/uninitialized tensor. It will be updated upon
  //       successful return of the function with the correct dtype and shape.
  //       However, the copied data will not be available until the compute
  //       stream has been synchronized.
  //
  // Returns:
  //   The status upon memory allocation / initialization of the
  //   "to" tensor, and enqueue of the copy onto the compute stream.
  //   Any failure of the copy itself will update the underlying
  //   stream status and propagate through the runtime independent
  //   of the caller.
  typedef std::function<Status(const Tensor& from, Tensor* to)>
      AsyncTensorDeviceCopyFn;
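
  // A minimal sketch of a DeviceCopyFn body that invokes the copier
  // (illustrative only; "MyVariantType" is a hypothetical type holding a
  // single Tensor field named 'tensor'):
  //
  //   Status MyDeviceCopyFn(const MyVariantType& from, MyVariantType* to,
  //                         AsyncTensorDeviceCopyFn copier) {
  //     // Enqueues the copy; to->tensor's data is valid only after the
  //     // compute stream has been synchronized.
  //     return copier(from.tensor, &to->tensor);
  //   }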

  // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
  // expected to be passed to the registration macro
  // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
  typedef std::function<Status(const Variant& from, Variant* to,
                               AsyncTensorDeviceCopyFn copy_fn)>
      AsyncVariantDeviceCopyFn;

  // Add a decode function to the registry.
  void RegisterDecodeFn(const std::string& type_name,
                        const VariantDecodeFn& decode_fn);

  // Returns nullptr if no decode function was found for the given TypeName.
  VariantDecodeFn* GetDecodeFn(StringPiece type_name);

  // Add a device copy function to the registry for the given copy direction.
  void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
                            const TypeIndex& type_index,
                            const AsyncVariantDeviceCopyFn& device_copy_fn) {
    AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
    CHECK_EQ(existing, nullptr)
        << "UnaryVariantDeviceCopy for direction: " << direction
        << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
        << " already registered";
    device_copy_fns.insert(
        std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
                  AsyncVariantDeviceCopyFn>(
            std::make_pair(direction, type_index), device_copy_fn));
  }

  // Returns nullptr if no copy function was found for the given
  // TypeName and direction.
  AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
    auto found = device_copy_fns.find(std::make_pair(direction, type_index));
    if (found == device_copy_fns.end()) return nullptr;
    return &found->second;
  }

  // Add a unary op function to the registry.
  void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
                         const TypeIndex& type_index,
                         const VariantUnaryOpFn& unary_op_fn) {
    VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
    CHECK_EQ(existing, nullptr)
        << "A VariantUnaryOpFn for type_index: "
        << port::MaybeAbiDemangle(type_index.name())
        << " already registered for device type: " << device;
    unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
        {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
  }

  // Returns nullptr if no unary op function was found for the given
  // op, device, and TypeName.
  VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
                                 const TypeIndex& type_index) {
    auto found = unary_op_fns.find({op, device, type_index});
    if (found == unary_op_fns.end()) return nullptr;
    return &found->second;
  }

  // Add a binary op function to the registry.
  void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
                          const TypeIndex& type_index,
                          const VariantBinaryOpFn& add_fn) {
    VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
    CHECK_EQ(existing, nullptr)
        << "A VariantBinaryOpFn for type_index: "
        << port::MaybeAbiDemangle(type_index.name())
        << " already registered for device type: " << device;
    binary_op_fns.insert(
        std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
            {op, GetPersistentStringPiece(device), type_index}, add_fn));
  }

  // Returns nullptr if no binary op function was found for the given
  // op, device and TypeName.
  VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
                                   const TypeIndex& type_index) {
    auto found = binary_op_fns.find({op, device, type_index});
    if (found == binary_op_fns.end()) return nullptr;
    return &found->second;
  }

  // Get a pointer to a global UnaryVariantOpRegistry object
  static UnaryVariantOpRegistry* Global() {
    static UnaryVariantOpRegistry* global_unary_variant_op_registry =
        new UnaryVariantOpRegistry;
    return global_unary_variant_op_registry;
  }

  // Get a pointer to a global persistent string storage object.
  // ISO/IEC C++ working draft N4296 clarifies that insertion into an
  // std::unordered_set does not invalidate memory locations of
  // *values* inside the set (though it may invalidate existing
  // iterators). In other words, one may safely point a StringPiece to
  // a value in the set without that StringPiece being invalidated by
  // future insertions.
  static std::unordered_set<string>* PersistentStringStorage();

 private:
  struct TypeIndexHash {
    std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
  };

  gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;

  // Map std::pair<Direction, type_name> to function.
  struct PairHash {
    template <typename Direction>
    std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, std::get<1>(x).hash_code());
      return ret;
    }
  };

  gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
               AsyncVariantDeviceCopyFn, PairHash>
      device_copy_fns;

  // Map std::tuple<Op, device, type_name> to function.
  //
  // Using std::tuple directly as the map key breaks by falling victim to
  // "too-perfect forwarding"; see
  // https://stackoverflow.com/questions/44475317/variadic-template-issue
  // and the references therein. FuncTuple below is the workaround.
  template <typename Op>
  struct FuncTuple {
    FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
        : op_type_(op), device_(dev), type_index_(type_index) {}
    Op op_type_;
    StringPiece device_;
    TypeIndex type_index_;
  };
  // friend declaration for operator==; needed for clang.
  template <typename Op>
  friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
  struct TupleHash {
    template <typename Op>
    std::size_t operator()(
        const std::tuple<Op, StringPiece, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
      ret = Hash64Combine(ret, std::get<2>(x).hash_code());
      return ret;
    }

    template <typename Op>
    std::size_t operator()(const FuncTuple<Op>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(x.op_type_);
      ret = Hash64Combine(ret, sp_hasher_(x.device_));
      ret = Hash64Combine(ret, x.type_index_.hash_code());
      return ret;
    }
    StringPieceHasher sp_hasher_;
  };
  gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
      unary_op_fns;
  gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
      binary_op_fns;

  // Find or insert a string into a persistent string storage
  // container; return the StringPiece pointing to the permanent string
  // location.
  static StringPiece GetPersistentStringPiece(const std::string& str) {
    const auto string_storage = PersistentStringStorage();
    auto found = string_storage->find(str);
    if (found == string_storage->end()) {
      auto inserted = string_storage->insert(str);
      return StringPiece(*inserted.first);
    } else {
      return StringPiece(*found);
    }
  }
};

template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
                       const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
  return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
         (lhs.type_index_ == rhs.type_index_);
}

// Decodes the Variant whose data_type has a registered decode
// function. Returns false if the Variant does not have a registered
// decode function, or if the decoding function fails.
//
// REQUIRES:
//   variant is not null.
//
bool DecodeUnaryVariant(Variant* variant);
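
// A typical call site (illustrative sketch only):
//
//   Variant v = ...;  // e.g. holds an encoded VariantTensorDataProto
//   if (!DecodeUnaryVariant(&v)) {
//     // No decode function is registered for v's TypeName, or decoding
//     // failed.
//   }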

// Copies a variant between CPU<->GPU, or between GPU<->GPU.
// The variant 'from' must have a registered DeviceCopyFn for the
// given direction. The returned variant 'to' will have
// (some subset of its) tensors stored on the destination, according
// to the registered DeviceCopyFn for the given direction. Returns
// an Internal error if the Variant does not have a registered
// DeviceCopyFn for the given direction, or if initiating the
// copy fails.
//
// REQUIRES:
//   'to' is not null.
//
Status VariantDeviceCopy(
    const VariantDeviceCopyDirection direction, const Variant& from,
    Variant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);

// Sets *v_out = unary_op(v). The variant v must have a registered
// UnaryOp function for the given Device. Returns an Internal error
// if v does not have a registered unary_op function for this device, or if
// the UnaryOp fails.
//
// REQUIRES:
//   v_out is not null.
//
template <typename Device>
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
                      Variant* v_out) {
  const std::string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
  if (unary_op_fn == nullptr) {
    return errors::Internal("No variant unary_op function found for op ",
                            VariantUnaryOpToString(op),
                            " Variant type_name: ", v.TypeName(),
                            " for device type: ", device);
  }
  return (*unary_op_fn)(ctx, v, v_out);
}
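
// For example, a ZerosLike kernel might call (illustrative sketch;
// "CPUDevice" stands for a device type such as Eigen::ThreadPoolDevice, and
// error handling is elided):
//
//   Variant v_out;
//   TF_RETURN_IF_ERROR(UnaryOpVariant<CPUDevice>(
//       ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));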

// Sets *out = binary_op(a, b). The variants a and b must be the same type
// and have a registered binary_op function for the given Device. Returns an
// Internal error if a and b are not the same type_name, if
// a does not have a registered op function for this device, or if
// the BinaryOp fails.
//
// REQUIRES:
//   out is not null.
//
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
                        const Variant& a, const Variant& b, Variant* out) {
  if (a.TypeId() != b.TypeId()) {
    return errors::Internal(
        "BinaryOpVariants: Variants a and b have different "
        "type ids. Type names: '",
        a.TypeName(), "' vs. '", b.TypeName(), "'");
  }
  const std::string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
  if (binary_op_fn == nullptr) {
    return errors::Internal("No variant binary_op function found for op ",
                            VariantBinaryOpToString(op),
                            " Variant type_name: '", a.TypeName(),
                            "' for device type: ", device);
  }
  return (*binary_op_fn)(ctx, a, b, out);
}
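
// For example, an Add kernel might call (illustrative sketch; "CPUDevice"
// stands for a device type such as Eigen::ThreadPoolDevice):
//
//   Variant out;
//   TF_RETURN_IF_ERROR(BinaryOpVariants<CPUDevice>(
//       ctx, ADD_VARIANT_BINARY_OP, a, b, &out));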

namespace variant_op_registry_fn_registration {

template <typename T>
class UnaryVariantDecodeRegistration {
 public:
  UnaryVariantDecodeRegistration(const std::string& type_name) {
    // The Variant is passed by pointer because it should be
    // mutable: get below may Decode the variant, which
    // is a self-mutating behavior. The variant is not modified in
    // any other way.
    UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
        type_name, [type_name](Variant* v) -> bool {
          DCHECK_NE(v, nullptr);
          VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
          if (t == nullptr) {
            return false;
          }
          Variant decoded = T();
          VariantTensorData data(std::move(*t));
          if (!decoded.Decode(std::move(data))) {
            return false;
          }
          std::swap(decoded, *v);
          return true;
        });
  }
};

template <typename T>
class UnaryVariantDeviceCopyRegistration {
 public:
  typedef std::function<Status(const T& t, T* t_out,
                               UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
      LocalVariantDeviceCopyFn;
  UnaryVariantDeviceCopyRegistration(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
      const LocalVariantDeviceCopyFn& device_copy_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
        direction, type_index,
        [type_index_name, device_copy_fn](
            const Variant& from, Variant* to,
            UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
                device_copy_tensor_fn) -> Status {
          DCHECK_NE(to, nullptr);
          *to = T();
          if (from.get<T>() == nullptr) {
            return errors::Internal(
                "VariantCopyToGPUFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *from.get<T>();
          T* t_out = to->get<T>();
          return device_copy_fn(t, t_out, device_copy_tensor_fn);
        });
  }
};

template <typename T>
class UnaryVariantUnaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
      LocalVariantUnaryOpFn;

 public:
  UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
                                  const TypeIndex& type_index,
                                  const LocalVariantUnaryOpFn& unary_op_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
        op, device, type_index,
        [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
                                       Variant* v_out) -> Status {
          DCHECK_NE(v_out, nullptr);
          *v_out = T();
          if (v.get<T>() == nullptr) {
            return errors::Internal(
                "VariantUnaryOpFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *v.get<T>();
          T* t_out = v_out->get<T>();
          return unary_op_fn(ctx, t, t_out);
        });
  }
};

template <typename T>
class UnaryVariantBinaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
                               T* out)>
      LocalVariantBinaryOpFn;

 public:
  UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
                                   const std::string& device,
                                   const TypeIndex& type_index,
                                   const LocalVariantBinaryOpFn& binary_op_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
        op, device, type_index,
        [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
                                        const Variant& b,
                                        Variant* out) -> Status {
          DCHECK_NE(out, nullptr);
          *out = T();
          if (a.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'a', type_index: ",
                type_index_name);
          }
          if (b.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'b', type_index: ",
                type_index_name);
          }
          const T& t_a = *a.get<T>();
          const T& t_b = *b.get<T>();
          T* t_out = out->get<T>();
          return binary_op_fn(ctx, t_a, t_b, t_out);
        });
  }
};

}  // namespace variant_op_registry_fn_registration

// Register a unary decode variant function for the given type.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) \
  static ::tensorflow::variant_op_registry_fn_registration::           \
      UnaryVariantDecodeRegistration<T>                                \
          register_unary_variant_op_decoder_fn_##ctr(type_name)
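
// For example (illustrative only; "MyVariantType" is a hypothetical type
// that implements the Variant Encode/Decode interface):
//
//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyVariantType, "MyVariantType");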

// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//       const T& t, T* t_out,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times. For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner. For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, TypeIndex::Make<T>(), device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static variant_op_registry_fn_registration::                     \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)
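
// For example (illustrative only; "MyVariantType" and "MyHostToDeviceCopyFn"
// are hypothetical):
//
//   static Status MyHostToDeviceCopyFn(
//       const MyVariantType& from, MyVariantType* to,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
//     return copier(from.tensor, &to->tensor);
//   }
//
//   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
//       MyVariantType, HOST_TO_DEVICE, MyHostToDeviceCopyFn);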

// Register a unary_op variant function with the signature:
//   Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantUnaryOp enum op.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
                                                 unary_op_function) \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, op, device, T, TypeIndex::Make<T>(), unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(               \
    ctr, op, device, T, type_index, unary_op_function)               \
  static ::tensorflow::variant_op_registry_fn_registration::         \
      UnaryVariantUnaryOpRegistration<T>                             \
          register_unary_variant_op_fn_##ctr(op, device, type_index, \
                                             unary_op_function)
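
// For example (illustrative only; "MyVariantType" and "MyZerosLikeFn" are
// hypothetical):
//
//   static Status MyZerosLikeFn(OpKernelContext* ctx, const MyVariantType& t,
//                               MyVariantType* t_out);
//
//   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
//       ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MyVariantType,
//       MyZerosLikeFn);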

// Register a binary_op variant function with the signature:
//   Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantBinaryOp enum op.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, TypeIndex::Make<T>(), binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                \
    ctr, op, device, T, type_index, binary_op_function)                \
  static ::tensorflow::variant_op_registry_fn_registration::           \
      UnaryVariantBinaryOpRegistration<T>                              \
          register_binary_variant_op_fn_##ctr(op, device, type_index, \
                                              binary_op_function)
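
// For example (illustrative only; "MyVariantType" and "MyAddFn" are
// hypothetical):
//
//   static Status MyAddFn(OpKernelContext* ctx, const MyVariantType& a,
//                         const MyVariantType& b, MyVariantType* out);
//
//   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
//       ADD_VARIANT_BINARY_OP, DEVICE_CPU, MyVariantType, MyAddFn);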

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_