• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
18 
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #define EIGEN_USE_THREADS
24 
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/type_index.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_encode_decode.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/platform/abi.h"
33 
34 namespace tensorflow {
35 
36 class OpKernelContext;
// A global UnaryVariantOpRegistry holds callback functions for the
// different Variant types.  It is used by ShapeOp, RankOp, SizeOp,
// Variant decoding, etc.
40 
// Unary operations for which per-type Variant callbacks can be registered.
// Enumerator values are spelled out explicitly and must not be renumbered.
enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

// Binary operations for which per-type Variant callbacks can be registered.
enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

// Printable names for the ops above (used when building error messages).
const char* VariantUnaryOpToString(VariantUnaryOp op);
const char* VariantBinaryOpToString(VariantBinaryOp op);

// Direction of a Variant copy; together with the Variant's TypeIndex it
// selects the registered device copy function.
enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};
62 
63 class UnaryVariantOpRegistry {
64  public:
65   typedef std::function<bool(Variant*)> VariantDecodeFn;
66   typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
67       VariantUnaryOpFn;
68   typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
69                                Variant*)>
70       VariantBinaryOpFn;
71 
72   // An AsyncTensorDeviceCopyFn is a function provided to
73   // the user-provided DeviceCopyFn callback as the third argument ("copier").
74   //
75   // Expected inputs:
76   //   from: A Tensor on the host (if performing cpu->gpu copy), or
77   //         device (if performing gpu->cpu or gpu->gpu copy).
78   //   to: An empty/uninitialized tensor.  It will be updated upon
79   //       successful return of the function with the correct dtype and shape.
80   //       However, the copied data will not be available until the compute
81   //       stream has been synchronized.
82   //
83   // Returns:
84   //   The status upon memory allocation / initialization of the
85   //   "to" tensor, and enqueue of the copy onto the compute stream.
86   //   Any failure of the copy itself will update the underlying
87   //   stream status and propagate through the runtime independent
88   //   of the caller.
89   typedef std::function<Status(const Tensor& from, Tensor* to)>
90       AsyncTensorDeviceCopyFn;
91 
92   // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
93   // expected to be passed to the registration macro
94   // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
95   typedef std::function<Status(const Variant& from, Variant* to,
96                                AsyncTensorDeviceCopyFn copy_fn)>
97       AsyncVariantDeviceCopyFn;
98 
99   // Add a decode function to the registry.
100   void RegisterDecodeFn(const std::string& type_name,
101                         const VariantDecodeFn& decode_fn);
102 
103   // Returns nullptr if no decode function was found for the given TypeName.
104   VariantDecodeFn* GetDecodeFn(StringPiece type_name);
105 
106   // Add a copy-to-GPU function to the registry.
RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const AsyncVariantDeviceCopyFn & device_copy_fn)107   void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
108                             const TypeIndex& type_index,
109                             const AsyncVariantDeviceCopyFn& device_copy_fn) {
110     AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
111     CHECK_EQ(existing, nullptr)
112         << "UnaryVariantDeviceCopy for direction: " << direction
113         << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
114         << " already registered";
115     device_copy_fns.insert(
116         std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
117                   AsyncVariantDeviceCopyFn>(
118             std::make_pair(direction, type_index), device_copy_fn));
119   }
120 
121   // Returns nullptr if no copy function was found for the given
122   // TypeName and direction.
GetDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index)123   AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
124       const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
125     auto found = device_copy_fns.find(std::make_pair(direction, type_index));
126     if (found == device_copy_fns.end()) return nullptr;
127     return &found->second;
128   }
129 
130   // Add a unary op function to the registry.
RegisterUnaryOpFn(VariantUnaryOp op,const std::string & device,const TypeIndex & type_index,const VariantUnaryOpFn & unary_op_fn)131   void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
132                          const TypeIndex& type_index,
133                          const VariantUnaryOpFn& unary_op_fn) {
134     VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
135     CHECK_EQ(existing, nullptr)
136         << "Unary VariantUnaryOpFn for type_index: "
137         << port::MaybeAbiDemangle(type_index.name())
138         << " already registered for device type: " << device;
139     unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
140         {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
141   }
142 
143   // Returns nullptr if no unary op function was found for the given
144   // op, device, and TypeName.
GetUnaryOpFn(VariantUnaryOp op,StringPiece device,const TypeIndex & type_index)145   VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
146                                  const TypeIndex& type_index) {
147     auto found = unary_op_fns.find({op, device, type_index});
148     if (found == unary_op_fns.end()) return nullptr;
149     return &found->second;
150   }
151 
152   // Add a binary op function to the registry.
RegisterBinaryOpFn(VariantBinaryOp op,const std::string & device,const TypeIndex & type_index,const VariantBinaryOpFn & add_fn)153   void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
154                           const TypeIndex& type_index,
155                           const VariantBinaryOpFn& add_fn) {
156     VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
157     CHECK_EQ(existing, nullptr)
158         << "Unary VariantBinaryOpFn for type_index: "
159         << port::MaybeAbiDemangle(type_index.name())
160         << " already registered for device type: " << device;
161     binary_op_fns.insert(
162         std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
163             {op, GetPersistentStringPiece(device), type_index}, add_fn));
164   }
165 
166   // Returns nullptr if no binary op function was found for the given
167   // op, device and TypeName.
GetBinaryOpFn(VariantBinaryOp op,StringPiece device,const TypeIndex & type_index)168   VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
169                                    const TypeIndex& type_index) {
170     auto found = binary_op_fns.find({op, device, type_index});
171     if (found == binary_op_fns.end()) return nullptr;
172     return &found->second;
173   }
174 
175   // Get a pointer to a global UnaryVariantOpRegistry object
Global()176   static UnaryVariantOpRegistry* Global() {
177     static UnaryVariantOpRegistry* global_unary_variant_op_registry =
178         new UnaryVariantOpRegistry;
179     return global_unary_variant_op_registry;
180   }
181 
182   // Get a pointer to a global persistent string storage object.
183   // ISO/IEC C++ working draft N4296 clarifies that insertion into an
184   // std::unordered_set does not invalidate memory locations of
185   // *values* inside the set (though it may invalidate existing
186   // iterators).  In other words, one may safely point a StringPiece to
187   // a value in the set without that StringPiece being invalidated by
188   // future insertions.
189   static std::unordered_set<string>* PersistentStringStorage();
190 
191  private:
192   struct TypeIndexHash {
operatorTypeIndexHash193     std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
194   };
195 
196   gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
197 
198   // Map std::pair<Direction, type_name> to function.
199   struct PairHash {
200     template <typename Direction>
operatorPairHash201     std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
202       // The hash of an enum is just its value as a std::size_t.
203       std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
204       ret = Hash64Combine(ret, std::get<1>(x).hash_code());
205       return ret;
206     }
207   };
208 
209   gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
210                AsyncVariantDeviceCopyFn, PairHash>
211       device_copy_fns;
212 
213   // Map std::tuple<Op, device, type_name> to function.
214 
215   // this breaks by falling victim to "too perfect forwarding"
216   // see https://stackoverflow.com/questions/44475317/variadic-template-issue
217   // and references therein
218   template <typename Op>
219   struct FuncTuple {
FuncTupleFuncTuple220     FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
221         : op_type_(op), device_(dev), type_index_(type_index) {}
222     Op op_type_;
223     StringPiece device_;
224     TypeIndex type_index_;
225   };
226   // friend declaration for operator==
227   // needed for clang
228   template <typename Op>
229   friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
230   struct TupleHash {
231     template <typename Op>
operatorTupleHash232     std::size_t operator()(
233         const std::tuple<Op, StringPiece, TypeIndex>& x) const {
234       // The hash of an enum is just its value as a std::size_t.
235       std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
236       ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
237       ret = Hash64Combine(ret, std::get<2>(x).hash_code());
238       return ret;
239     }
240 
241     template <typename Op>
operatorTupleHash242     std::size_t operator()(const FuncTuple<Op>& x) const {
243       // The hash of an enum is just its value as a std::size_t.
244       std::size_t ret = static_cast<std::size_t>(x.op_type_);
245       ret = Hash64Combine(ret, sp_hasher_(x.device_));
246       ret = Hash64Combine(ret, x.type_index_.hash_code());
247       return ret;
248     }
249     StringPieceHasher sp_hasher_;
250   };
251   gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
252       unary_op_fns;
253   gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
254       binary_op_fns;
255 
256   // Find or insert a string into a persistent string storage
257   // container; return the StringPiece pointing to the permanent string
258   // location.
GetPersistentStringPiece(const std::string & str)259   static StringPiece GetPersistentStringPiece(const std::string& str) {
260     const auto string_storage = PersistentStringStorage();
261     auto found = string_storage->find(str);
262     if (found == string_storage->end()) {
263       auto inserted = string_storage->insert(str);
264       return StringPiece(*inserted.first);
265     } else {
266       return StringPiece(*found);
267     }
268   }
269 };
270 template <typename Op>
271 inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
272                        const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
273   return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
274          (lhs.type_index_ == rhs.type_index_);
275 }
276 
// Decodes the Variant whose data_type has a registered decode
// function.  Returns false if the Variant does not have a
// registered decode function, or if the decoding function fails.
280 //
281 // REQUIRES:
282 //   variant is not null.
283 //
284 bool DecodeUnaryVariant(Variant* variant);
285 
286 // Copies a variant between CPU<->GPU, or between GPU<->GPU.
287 // The variant 'from' must have a registered DeviceCopyFn for the
288 // given direction.  The returned variant 'to' will have
289 // (some subset of its) tensors stored on destination according to the
290 // registered DeviceCopyFn function for the given direction.  Returns
291 // an Internal error if the Variant does not have a registered
292 // DeviceCopyFn function for the given direction, or if initiating the
293 // copy fails.
294 //
295 // REQUIRES:
296 //   'to' is not null.
297 //
298 Status VariantDeviceCopy(
299     const VariantDeviceCopyDirection direction, const Variant& from,
300     Variant* to,
301     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
302 
303 // Sets *v_out = unary_op(v).  The variant v must have a registered
304 // UnaryOp function for the given Device.  Returns an Internal error
305 // if v does not have a registered unary_op function for this device, or if
306 // UnaryOp fails.
307 //
308 // REQUIRES:
309 //   v_out is not null.
310 //
311 template <typename Device>
UnaryOpVariant(OpKernelContext * ctx,VariantUnaryOp op,const Variant & v,Variant * v_out)312 Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
313                       Variant* v_out) {
314   const std::string& device = DeviceName<Device>::value;
315   UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
316       UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
317   if (unary_op_fn == nullptr) {
318     return errors::Internal("No unary variant unary_op function found for op ",
319                             VariantUnaryOpToString(op),
320                             " Variant type_name: ", v.TypeName(),
321                             " for device type: ", device);
322   }
323   return (*unary_op_fn)(ctx, v, v_out);
324 }
325 
326 // Sets *out = binary_op(a, b).  The variants a and b must be the same type
327 // and have a registered binary_op function for the given Device.  Returns an
328 // Internal error if a and b are not the same type_name or if
329 // if a does not have a registered op function for this device, or if
330 // BinaryOp fails.
331 //
332 // REQUIRES:
333 //   out is not null.
334 //
335 template <typename Device>
BinaryOpVariants(OpKernelContext * ctx,VariantBinaryOp op,const Variant & a,const Variant & b,Variant * out)336 Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
337                         const Variant& a, const Variant& b, Variant* out) {
338   if (a.TypeId() != b.TypeId()) {
339     return errors::Internal(
340         "BinaryOpVariants: Variants a and b have different "
341         "type ids.  Type names: '",
342         a.TypeName(), "' vs. '", b.TypeName(), "'");
343   }
344   const std::string& device = DeviceName<Device>::value;
345   UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
346       UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
347   if (binary_op_fn == nullptr) {
348     return errors::Internal("No unary variant binary_op function found for op ",
349                             VariantBinaryOpToString(op),
350                             " Variant type_name: '", a.TypeName(),
351                             "' for device type: ", device);
352   }
353   return (*binary_op_fn)(ctx, a, b, out);
354 }
355 
356 namespace variant_op_registry_fn_registration {
357 
358 template <typename T>
359 class UnaryVariantDecodeRegistration {
360  public:
UnaryVariantDecodeRegistration(const std::string & type_name)361   UnaryVariantDecodeRegistration(const std::string& type_name) {
362     // The Variant is passed by pointer because it should be
363     // mutable: get below may Decode the variant, which
364     // is a self-mutating behavior.  The variant is not modified in
365     // any other way.
366     UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
367         type_name, [type_name](Variant* v) -> bool {
368           DCHECK_NE(v, nullptr);
369           VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
370           if (t == nullptr) {
371             return false;
372           }
373           Variant decoded = T();
374           VariantTensorData data(std::move(*t));
375           if (!decoded.Decode(std::move(data))) {
376             return false;
377           }
378           std::swap(decoded, *v);
379           return true;
380         });
381   }
382 };
383 
384 template <typename T>
385 class UnaryVariantDeviceCopyRegistration {
386  public:
387   typedef std::function<Status(const T& t, T* t_out,
388                                UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
389       LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const LocalVariantDeviceCopyFn & device_copy_fn)390   UnaryVariantDeviceCopyRegistration(
391       const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
392       const LocalVariantDeviceCopyFn& device_copy_fn) {
393     const std::string type_index_name =
394         port::MaybeAbiDemangle(type_index.name());
395     UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
396         direction, type_index,
397         [type_index_name, device_copy_fn](
398             const Variant& from, Variant* to,
399             UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
400                 device_copy_tensor_fn) -> Status {
401           DCHECK_NE(to, nullptr);
402           *to = T();
403           if (from.get<T>() == nullptr) {
404             return errors::Internal(
405                 "VariantCopyToGPUFn: Could not access object, type_index: ",
406                 type_index_name);
407           }
408           const T& t = *from.get<T>();
409           T* t_out = to->get<T>();
410           return device_copy_fn(t, t_out, device_copy_tensor_fn);
411         });
412   }
413 };
414 
415 template <typename T>
416 class UnaryVariantUnaryOpRegistration {
417   typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
418       LocalVariantUnaryOpFn;
419 
420  public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op,const std::string & device,const TypeIndex & type_index,const LocalVariantUnaryOpFn & unary_op_fn)421   UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
422                                   const TypeIndex& type_index,
423                                   const LocalVariantUnaryOpFn& unary_op_fn) {
424     const std::string type_index_name =
425         port::MaybeAbiDemangle(type_index.name());
426     UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
427         op, device, type_index,
428         [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
429                                        Variant* v_out) -> Status {
430           DCHECK_NE(v_out, nullptr);
431           *v_out = T();
432           if (v.get<T>() == nullptr) {
433             return errors::Internal(
434                 "VariantUnaryOpFn: Could not access object, type_index: ",
435                 type_index_name);
436           }
437           const T& t = *v.get<T>();
438           T* t_out = v_out->get<T>();
439           return unary_op_fn(ctx, t, t_out);
440         });
441   }
442 };
443 
444 template <typename T>
445 class UnaryVariantBinaryOpRegistration {
446   typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
447                                T* out)>
448       LocalVariantBinaryOpFn;
449 
450  public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op,const std::string & device,const TypeIndex & type_index,const LocalVariantBinaryOpFn & binary_op_fn)451   UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
452                                    const std::string& device,
453                                    const TypeIndex& type_index,
454                                    const LocalVariantBinaryOpFn& binary_op_fn) {
455     const std::string type_index_name =
456         port::MaybeAbiDemangle(type_index.name());
457     UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
458         op, device, type_index,
459         [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
460                                         const Variant& b,
461                                         Variant* out) -> Status {
462           DCHECK_NE(out, nullptr);
463           *out = T();
464           if (a.get<T>() == nullptr) {
465             return errors::Internal(
466                 "VariantBinaryOpFn: Could not access object 'a', type_index: ",
467                 type_index_name);
468           }
469           if (b.get<T>() == nullptr) {
470             return errors::Internal(
471                 "VariantBinaryOpFn: Could not access object 'b', type_index: ",
472                 type_index_name);
473           }
474           const T& t_a = *a.get<T>();
475           const T& t_b = *b.get<T>();
476           T* t_out = out->get<T>();
477           return binary_op_fn(ctx, t_a, t_b, t_out);
478         });
479   }
480 };
481 
482 };  // namespace variant_op_registry_fn_registration
483 
// Registers a decode function for Variants of C++ type T whose serialized
// type name is `type_name`.  Expands to a static
// UnaryVariantDecodeRegistration<T> object whose constructor performs the
// registration.  __COUNTER__ is routed through the _UNIQ_HELPER layer so it
// is expanded before `##` pastes it, giving each object a unique name.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(counter, T, \
                                                           type_name)  \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, type_name) \
  static ::tensorflow::variant_op_registry_fn_registration::               \
      UnaryVariantDecodeRegistration<T>                                    \
          register_unary_variant_op_decoder_fn_##counter(type_name)
495 
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// Names in the expansion are qualified with ::tensorflow:: (as in the other
// REGISTER_* macros), so the macro also expands correctly outside of
// namespace tensorflow; the caller must then qualify `direction` itself.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, ::tensorflow::TypeIndex::Make<T>(),         \
      device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static ::tensorflow::variant_op_registry_fn_registration::       \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)
545 
// Register a unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantUnaryOp enum op.
//
// Names in the expansion are qualified with ::tensorflow::, so the macro
// also expands correctly outside of namespace tensorflow; the caller must
// then qualify `op` and `device` itself.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,       \
                                                 unary_op_function)   \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(               \
      __COUNTER__, op, device, T, ::tensorflow::TypeIndex::Make<T>(), \
      unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                      \
    ctr, op, device, T, type_index, unary_op_function)                      \
  static ::tensorflow::variant_op_registry_fn_registration::                \
      UnaryVariantUnaryOpRegistration<T>                                    \
          register_unary_variant_unary_op_fn_##ctr(op, device, type_index, \
                                                   unary_op_function)
566 
// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantBinaryOp enum op.
//
// Names in the expansion are qualified with ::tensorflow::, so the macro
// also expands correctly outside of namespace tensorflow; the caller must
// then qualify `op` and `device` itself.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, ::tensorflow::TypeIndex::Make<T>(), \
      binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                      \
    ctr, op, device, T, type_index, binary_op_function)                      \
  static ::tensorflow::variant_op_registry_fn_registration::                 \
      UnaryVariantBinaryOpRegistration<T>                                    \
          register_unary_variant_binary_op_fn_##ctr(op, device, type_index, \
                                                    binary_op_function)
587 
588 }  // end namespace tensorflow
589 
590 #endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
591