1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
18
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22
23 #define EIGEN_USE_THREADS
24
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/type_index.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_encode_decode.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/platform/abi.h"
33
34 namespace tensorflow {
35
36 class OpKernelContext;
// A global UnaryVariantOpRegistry holds callback functions for different
// variant types. These callbacks are intended to be used by ops such as
// ShapeOp, RankOp, and SizeOp, by Variant decoding, etc.
40
// Identifies which unary operation a callback registered through
// UnaryVariantOpRegistry::RegisterUnaryOpFn implements; dispatched by
// UnaryOpVariant<Device>().
enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2
};

// Identifies which binary operation a callback registered through
// UnaryVariantOpRegistry::RegisterBinaryOpFn implements; dispatched by
// BinaryOpVariants<Device>().
enum VariantBinaryOp { INVALID_VARIANT_BINARY_OP = 0, ADD_VARIANT_BINARY_OP = 1 };

// Direction of a Variant copy; selects which function registered through
// UnaryVariantOpRegistry::RegisterDeviceCopyFn is used by VariantDeviceCopy().
enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3
};
58
59 class UnaryVariantOpRegistry {
60 public:
61 typedef std::function<bool(Variant*)> VariantDecodeFn;
62 typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
63 VariantUnaryOpFn;
64 typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
65 Variant*)>
66 VariantBinaryOpFn;
67
68 // An AsyncTensorDeviceCopyFn is a function provided to
69 // the user-provided DeviceCopyFn callback as the third argument ("copier").
70 //
71 // Expected inputs:
72 // from: A Tensor on the host (if performing cpu->gpu copy), or
73 // device (if performing gpu->cpu or gpu->gpu copy).
74 // to: An empty/uninitialized tensor. It will be updated upon
75 // successful return of the function with the correct dtype and shape.
76 // However, the copied data will not be available until the compute
77 // stream has been synchronized.
78 //
79 // Returns:
80 // The status upon memory allocation / initialization of the
81 // "to" tensor, and enqueue of the copy onto the compute stream.
82 // Any failure of the copy itself will update the underlying
83 // stream status and propagate through the runtime independent
84 // of the caller.
85 typedef std::function<Status(const Tensor& from, Tensor* to)>
86 AsyncTensorDeviceCopyFn;
87
88 // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
89 // expected to be passed to the registration macro
90 // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
91 typedef std::function<Status(const Variant& from, Variant* to,
92 AsyncTensorDeviceCopyFn copy_fn)>
93 AsyncVariantDeviceCopyFn;
94
95 // Add a decode function to the registry.
96 void RegisterDecodeFn(const string& type_name,
97 const VariantDecodeFn& decode_fn);
98
99 // Returns nullptr if no decode function was found for the given TypeName.
100 VariantDecodeFn* GetDecodeFn(StringPiece type_name);
101
102 // Add a copy-to-GPU function to the registry.
103 void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
104 const TypeIndex& type_index,
105 const AsyncVariantDeviceCopyFn& device_copy_fn);
106
107 // Returns nullptr if no copy function was found for the given
108 // TypeName and direction.
109 AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
110 const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
111
112 // Add a unary op function to the registry.
113 void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
114 const TypeIndex& type_index,
115 const VariantUnaryOpFn& unary_op_fn);
116
117 // Returns nullptr if no unary op function was found for the given
118 // op, device, and TypeName.
119 VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
120 const TypeIndex& type_index);
121
122 // Add a binary op function to the registry.
123 void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
124 const TypeIndex& type_index,
125 const VariantBinaryOpFn& add_fn);
126
127 // Returns nullptr if no binary op function was found for the given
128 // op, device and TypeName.
129 VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
130 const TypeIndex& type_index);
131
132 // Get a pointer to a global UnaryVariantOpRegistry object
133 static UnaryVariantOpRegistry* Global();
134
135 // Get a pointer to a global persistent string storage object.
136 // ISO/IEC C++ working draft N4296 clarifies that insertion into an
137 // std::unordered_set does not invalidate memory locations of
138 // *values* inside the set (though it may invalidate existing
139 // iterators). In other words, one may safely point a StringPiece to
140 // a value in the set without that StringPiece being invalidated by
141 // future insertions.
142 static std::unordered_set<string>* PersistentStringStorage();
143
144 private:
145 struct TypeIndexHash {
operatorTypeIndexHash146 std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
147 };
148
149 gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
150
151 // Map std::pair<Direction, type_name> to function.
152 struct PairHash {
153 template <typename Direction>
operatorPairHash154 std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
155 // The hash of an enum is just its value as a std::size_t.
156 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
157 ret = Hash64Combine(ret, std::get<1>(x).hash_code());
158 return ret;
159 }
160 };
161
162 gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
163 AsyncVariantDeviceCopyFn, PairHash>
164 device_copy_fns;
165
166 // Map std::tuple<Op, device, type_name> to function.
167
168 // this breaks by falling victim to "too perfect forwarding"
169 // see https://stackoverflow.com/questions/44475317/variadic-template-issue
170 // and references therein
171 template <typename Op>
172 struct FuncTuple {
FuncTupleFuncTuple173 FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
174 : op_type_(op), device_(dev), type_index_(type_index) {}
175 Op op_type_;
176 StringPiece device_;
177 TypeIndex type_index_;
178 };
179 // friend declaration for operator==
180 // needed for clang
181 template <typename Op>
182 friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
183 struct TupleHash {
184 template <typename Op>
operatorTupleHash185 std::size_t operator()(
186 const std::tuple<Op, StringPiece, TypeIndex>& x) const {
187 // The hash of an enum is just its value as a std::size_t.
188 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
189 ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
190 ret = Hash64Combine(ret, std::get<2>(x).hash_code());
191 return ret;
192 }
193
194 template <typename Op>
operatorTupleHash195 std::size_t operator()(const FuncTuple<Op>& x) const {
196 // The hash of an enum is just its value as a std::size_t.
197 std::size_t ret = static_cast<std::size_t>(x.op_type_);
198 ret = Hash64Combine(ret, sp_hasher_(x.device_));
199 ret = Hash64Combine(ret, x.type_index_.hash_code());
200 return ret;
201 }
202 StringPieceHasher sp_hasher_;
203 };
204 gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
205 unary_op_fns;
206 gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
207 binary_op_fns;
208
209 // Find or insert a string into a persistent string storage
210 // container; return the StringPiece pointing to the permanent string
211 // location.
GetPersistentStringPiece(const string & str)212 static StringPiece GetPersistentStringPiece(const string& str) {
213 const auto string_storage = PersistentStringStorage();
214 auto found = string_storage->find(str);
215 if (found == string_storage->end()) {
216 auto inserted = string_storage->insert(str);
217 return StringPiece(*inserted.first);
218 } else {
219 return StringPiece(*found);
220 }
221 }
222 };
223 template <typename Op>
224 inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
225 const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
226 return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
227 (lhs.type_index_ == rhs.type_index_);
228 }
229
230 // Decodes the Variant whose data_type has a registered decode
231 // function. Returns an Internal error if the Variant does not have a
232 // registered decode function, or if the decoding function fails.
233 //
234 // REQUIRES:
235 // variant is not null.
236 //
237 bool DecodeUnaryVariant(Variant* variant);
238
239 // Copies a variant between CPU<->GPU, or between GPU<->GPU.
240 // The variant 'from' must have a registered DeviceCopyFn for the
241 // given direction. The returned variant 'to' will have
242 // (some subset of its) tensors stored on destination according to the
243 // registered DeviceCopyFn function for the given direction. Returns
244 // an Internal error if the Variant does not have a registered
245 // DeviceCopyFn function for the given direction, or if initiating the
246 // copy fails.
247 //
248 // REQUIRES:
249 // 'to' is not null.
250 //
251 Status VariantDeviceCopy(
252 const VariantDeviceCopyDirection direction, const Variant& from,
253 Variant* to,
254 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
255
256 // Sets *v_out = unary_op(v). The variant v must have a registered
257 // UnaryOp function for the given Device. Returns an Internal error
258 // if v does not have a registered unary_op function for this device, or if
259 // UnaryOp fails.
260 //
261 // REQUIRES:
262 // v_out is not null.
263 //
264 template <typename Device>
UnaryOpVariant(OpKernelContext * ctx,VariantUnaryOp op,const Variant & v,Variant * v_out)265 Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
266 Variant* v_out) {
267 const string& device = DeviceName<Device>::value;
268 UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
269 UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
270 if (unary_op_fn == nullptr) {
271 return errors::Internal(
272 "No unary variant unary_op function found for unary variant op enum: ",
273 op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
274 }
275 return (*unary_op_fn)(ctx, v, v_out);
276 }
277
278 // Sets *out = binary_op(a, b). The variants a and b must be the same type
279 // and have a registered binary_op function for the given Device. Returns an
280 // Internal error if a and b are not the same type_name or if
281 // if a does not have a registered op function for this device, or if
282 // BinaryOp fails.
283 //
284 // REQUIRES:
285 // out is not null.
286 //
287 template <typename Device>
BinaryOpVariants(OpKernelContext * ctx,VariantBinaryOp op,const Variant & a,const Variant & b,Variant * out)288 Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
289 const Variant& a, const Variant& b, Variant* out) {
290 if (a.TypeId() != b.TypeId()) {
291 return errors::Internal(
292 "BianryOpVariants: Variants a and b have different "
293 "type ids. Type names: '",
294 a.TypeName(), "' vs. '", b.TypeName(), "'");
295 }
296 const string& device = DeviceName<Device>::value;
297 UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
298 UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
299 if (binary_op_fn == nullptr) {
300 return errors::Internal(
301 "No unary variant binary_op function found for binary variant op "
302 "enum: ",
303 op, " Variant type_name: '", a.TypeName(), "' for device type: ",
304 device);
305 }
306 return (*binary_op_fn)(ctx, a, b, out);
307 }
308
309 namespace variant_op_registry_fn_registration {
310
311 template <typename T>
312 class UnaryVariantDecodeRegistration {
313 public:
UnaryVariantDecodeRegistration(const string & type_name)314 UnaryVariantDecodeRegistration(const string& type_name) {
315 // The Variant is passed by pointer because it should be
316 // mutable: get below may Decode the variant, which
317 // is a self-mutating behavior. The variant is not modified in
318 // any other way.
319 UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
320 type_name, [type_name](Variant* v) -> bool {
321 DCHECK_NE(v, nullptr);
322 VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
323 if (t == nullptr) {
324 return false;
325 }
326 Variant decoded = T();
327 VariantTensorData data(std::move(*t));
328 if (!decoded.Decode(std::move(data))) {
329 return false;
330 }
331 std::swap(decoded, *v);
332 return true;
333 });
334 }
335 };
336
337 template <typename T>
338 class UnaryVariantDeviceCopyRegistration {
339 public:
340 typedef std::function<Status(const T& t, T* t_out,
341 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
342 LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const LocalVariantDeviceCopyFn & device_copy_fn)343 UnaryVariantDeviceCopyRegistration(
344 const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
345 const LocalVariantDeviceCopyFn& device_copy_fn) {
346 const string type_index_name = port::MaybeAbiDemangle(type_index.name());
347 UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
348 direction, type_index,
349 [type_index_name, device_copy_fn](
350 const Variant& from, Variant* to,
351 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
352 device_copy_tensor_fn) -> Status {
353 DCHECK_NE(to, nullptr);
354 *to = T();
355 if (from.get<T>() == nullptr) {
356 return errors::Internal(
357 "VariantCopyToGPUFn: Could not access object, type_index: ",
358 type_index_name);
359 }
360 const T& t = *from.get<T>();
361 T* t_out = to->get<T>();
362 return device_copy_fn(t, t_out, device_copy_tensor_fn);
363 });
364 }
365 };
366
367 template <typename T>
368 class UnaryVariantUnaryOpRegistration {
369 typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
370 LocalVariantUnaryOpFn;
371
372 public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op,const string & device,const TypeIndex & type_index,const LocalVariantUnaryOpFn & unary_op_fn)373 UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
374 const TypeIndex& type_index,
375 const LocalVariantUnaryOpFn& unary_op_fn) {
376 const string type_index_name = port::MaybeAbiDemangle(type_index.name());
377 UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
378 op, device, type_index,
379 [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
380 Variant* v_out) -> Status {
381 DCHECK_NE(v_out, nullptr);
382 *v_out = T();
383 if (v.get<T>() == nullptr) {
384 return errors::Internal(
385 "VariantUnaryOpFn: Could not access object, type_index: ",
386 type_index_name);
387 }
388 const T& t = *v.get<T>();
389 T* t_out = v_out->get<T>();
390 return unary_op_fn(ctx, t, t_out);
391 });
392 }
393 };
394
395 template <typename T>
396 class UnaryVariantBinaryOpRegistration {
397 typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
398 T* out)>
399 LocalVariantBinaryOpFn;
400
401 public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op,const string & device,const TypeIndex & type_index,const LocalVariantBinaryOpFn & binary_op_fn)402 UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
403 const TypeIndex& type_index,
404 const LocalVariantBinaryOpFn& binary_op_fn) {
405 const string type_index_name = port::MaybeAbiDemangle(type_index.name());
406 UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
407 op, device, type_index,
408 [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
409 const Variant& b,
410 Variant* out) -> Status {
411 DCHECK_NE(out, nullptr);
412 *out = T();
413 if (a.get<T>() == nullptr) {
414 return errors::Internal(
415 "VariantBinaryOpFn: Could not access object 'a', type_index: ",
416 type_index_name);
417 }
418 if (b.get<T>() == nullptr) {
419 return errors::Internal(
420 "VariantBinaryOpFn: Could not access object 'b', type_index: ",
421 type_index_name);
422 }
423 const T& t_a = *a.get<T>();
424 const T& t_b = *b.get<T>();
425 T* t_out = out->get<T>();
426 return binary_op_fn(ctx, t_a, t_b, t_out);
427 });
428 }
429 };
430
431 }; // namespace variant_op_registry_fn_registration
432
// Registers a decode function for Variants holding a T whose serialized type
// name is `type_name`. __COUNTER__ gives every registration object a unique
// name within the translation unit; the intermediate _UNIQ_HELPER macro is
// needed so that __COUNTER__ is expanded before it is token-pasted.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(counter, T, \
                                                           type_name)  \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, type_name) \
  static variant_op_registry_fn_registration::                             \
      UnaryVariantDecodeRegistration<T>                                    \
          register_unary_variant_op_decoder_fn_##counter(type_name)
444
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Registers a device copy function for Variants holding a T, for a single
// copy direction (a VariantDeviceCopyDirection value). `device_copy_fn` must
// be callable as:
//
//   Status device_copy_fn(
//       const T& t, T* t_out,
//       UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn copier);
//
// and may invoke `copier` zero or more times. For the contract of `copier`,
// see the comments on UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// device_copy_fn is free to keep some tensors on the host, e.g. by assigning
// to->tensor = from.tensor (when from.tensor already lives on the host), or
// by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and filling in its values by hand.
//
// When it does so, the CopyFns for HOST_TO_DEVICE, DEVICE_TO_HOST, and
// DEVICE_TO_DEVICE must all treat those host-resident tensors the same way:
// any "always on host" tensor must be copied manually in every direction,
// rather than e.g.
//   - performing a host-to-host copy in one direction, and
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,     \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(           \
      __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    counter, T, direction, type_idx, device_copy_fn)                      \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      counter, T, direction, type_idx, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    counter, T, direction, type_idx, device_copy_fn)               \
  static variant_op_registry_fn_registration::                     \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##counter(      \
              direction, type_idx, device_copy_fn)
494
// Register a unary_op variant function with the signature:
//   Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// for Variants holding a T (TypeIndex MakeTypeIndex<T>()), for device string
// `device`, and for the VariantUnaryOp enum value `op`.
//
// The static registration object gets its own name prefix
// (register_unary_variant_op_unary_op_fn_) rather than reusing the decode
// registration's prefix; __COUNTER__ keeps each name unique per TU.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
                                                 unary_op_function) \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, unary_op_function)        \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(  \
    ctr, op, device, T, type_index, unary_op_function) \
  static variant_op_registry_fn_registration::          \
      UnaryVariantUnaryOpRegistration<T>                \
          register_unary_variant_op_unary_op_fn_##ctr(  \
              op, device, type_index, unary_op_function)
515
// Register a binary_op variant function with the signature:
//   Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// for Variants holding a T (TypeIndex MakeTypeIndex<T>()), for device string
// `device`, and for the VariantBinaryOp enum value `op`.
//
// The static registration object gets its own name prefix
// (register_unary_variant_op_binary_op_fn_) rather than reusing the decode
// registration's prefix; __COUNTER__ keeps each name unique per TU.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(  \
    ctr, op, device, T, type_index, binary_op_function) \
  static variant_op_registry_fn_registration::           \
      UnaryVariantBinaryOpRegistration<T>                \
          register_unary_variant_op_binary_op_fn_##ctr(  \
              op, device, type_index, binary_op_function)
536
537 } // end namespace tensorflow
538
539 #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
540