• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
18 
19 #define _USE_MATH_DEFINES
20 #include <cmath>
21 #include <functional>
22 #include <type_traits>
23 
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/numeric_types.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 
29 namespace Eigen {
30 namespace internal {
31 
32 #if GOOGLE_CUDA
33 template <>
34 struct scalar_arg_op<std::complex<float>> {
35   EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op)
36   typedef typename Eigen::NumTraits<std::complex<float>>::Real result_type;
37   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
38       const std::complex<float>& a) const {
39     return ::atan2f(a.imag(), a.real());
40   }
41 };
42 
43 template <>
44 struct scalar_arg_op<std::complex<double>> {
45   EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op)
46   typedef typename Eigen::NumTraits<std::complex<double>>::Real result_type;
47   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()(
48       const std::complex<double>& a) const {
49     return ::atan2(a.imag(), a.real());
50   }
51 };
52 #endif
53 
54 #if EIGEN_HAS_CXX11_MATH == 0
55 template <typename T>
56 struct scalar_asinh_op {
57   EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
58   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
59     return std::asinh(a);
60   }
61 };
62 template <typename T>
63 struct functor_traits<scalar_asinh_op<T>> {
64   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
65 };
66 
67 template <typename T>
68 struct scalar_acosh_op {
69   EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
70   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
71     return std::acosh(a);
72   }
73 };
74 template <typename T>
75 struct functor_traits<scalar_acosh_op<T>> {
76   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
77 };
78 
79 template <typename T>
80 struct scalar_atanh_op {
81   EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
82   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
83     return std::atanh(a);
84   }
85 };
86 template <typename T>
87 struct functor_traits<scalar_atanh_op<T>> {
88   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
89 };
90 #endif
91 
92 template <typename Scalar, typename Exponent>
93 struct safe_scalar_binary_pow_op {
94   static_assert(std::is_integral<Scalar>::value, "Integer type expected");
95   static_assert(std::is_integral<Exponent>::value &&
96                     std::is_signed<Exponent>::value,
97                 "Signed integer type expected");
98 
99   bool* const error;
100 
101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
102       : error(error) {}
103 
104   EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
105                                              const Exponent& b) const {
106     const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
107     if (TF_PREDICT_TRUE(safe_b >= 0)) {
108       return numext::pow(a, safe_b);
109     } else {
110       *error = true;
111       return 0;
112     }
113   }
114 };
115 
116 template <typename Scalar, typename Exponent>
117 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
118   enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
119 };
120 
121 template <typename T, typename DivOrMod>
122 struct safe_div_or_mod_op {
123   static_assert(std::is_integral<T>::value, "Integer type expected");
124 
125   bool* const error;
126 
127   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
128       : error(error) {}
129 
130   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
131                                                      const T& b) const {
132     const T safe_b = tensorflow::internal::SubtleMustCopy(b);
133     if (TF_PREDICT_TRUE(safe_b != 0)) {
134       return DivOrMod()(a, safe_b);
135     } else {
136       *error = true;
137       return 0;
138     }
139   }
140 };
141 
142 template <typename T, typename DivOrMod>
143 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
144   enum {
145     Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
146     PacketAccess = false,
147   };
148 };
149 
150 template <typename T, typename Binary>
151 struct no_nan_op {
152   EIGEN_EMPTY_STRUCT_CTOR(no_nan_op)
153   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
154                                                      const T& b) const {
155     if (b != T(0)) {
156       return Binary()(a, b);
157     } else {
158       return T(0);
159     }
160   }
161   template <typename Packet>
162   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
163                                                         const Packet& b) const {
164     const Packet mask = pcmp_eq(b, pzero(b));
165     const Packet quotient = Binary().packetOp(a, b);
166     return pandnot(quotient, mask);
167   }
168 };
169 
170 template <typename T>
171 struct div_no_nan_op : public no_nan_op<T, scalar_quotient_op<T>> {
172   EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
173 };
174 
175 template <typename T>
176 struct functor_traits<div_no_nan_op<T>> {
177   enum {
178     Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
179     PacketAccess = true,
180   };
181 };
182 
183 template <typename T>
184 struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> {
185   EIGEN_EMPTY_STRUCT_CTOR(mul_no_nan_op)
186 };
187 
188 template <typename T>
189 struct functor_traits<mul_no_nan_op<T>> {
190   enum {
191     Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost,
192     PacketAccess = true,
193   };
194 };
195 
196 // scalar_left and scalar_right are template helpers to partially
197 // apply a binary function.
198 //
199 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a
200 // unary functor g_x(y) = f(x, y), where x is provided via the
201 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
202 // f(x, y).
203 
204 template <typename Tout, typename Tin, typename Binary,
205           bool is_scalar_in_host_memory = false>
206 struct scalar_left : private Binary {
207   using result_type = Tout;
208   using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;
209 
210   const Tin* left;
211   TinPacket left_packet;  // initialized iff is_scalar_in_host_memory == true
212 
213   EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
214 
215   template <typename... Args>
216   EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
217       : Binary(args...), left(c) {
218     if (is_scalar_in_host_memory) {
219       left_packet = Eigen::internal::pset1<TinPacket>(*left);
220     }
221   }
222 
223   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
224     return Binary::operator()(*left, right);
225   }
226 
227   template <typename Packet,
228             typename std::enable_if<!is_scalar_in_host_memory ||
229                                         !std::is_same<TinPacket, Packet>::value,
230                                     int>::type = 0>
231   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
232     const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
233     return Binary::packetOp(left_packet, right_packet);
234   }
235 
236   template <typename Packet,
237             typename std::enable_if<is_scalar_in_host_memory &&
238                                         std::is_same<TinPacket, Packet>::value,
239                                     int>::type = 0>
240   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
241     return Binary::packetOp(left_packet, right_packet);
242   }
243 };
244 
245 template <typename Tout, typename Tin, typename Binary,
246           bool is_scalar_in_host_memory>
247 struct functor_traits<
248     scalar_left<Tout, Tin, Binary, is_scalar_in_host_memory>> {
249   enum {
250     Cost = functor_traits<Binary>::Cost,
251     PacketAccess = functor_traits<Binary>::PacketAccess,
252   };
253 };
254 
255 template <typename Tout, typename Tin, typename Binary,
256           bool is_scalar_in_host_memory = false>
257 struct scalar_right : private Binary {
258   using result_type = Tout;
259   using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;
260 
261   const Tin* right;
262   TinPacket right_packet;  // initialized iff is_scalar_in_host_memory == true
263 
264   EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
265 
266   template <typename... Args>
267   EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
268       : Binary(args...), right(c) {
269     if (is_scalar_in_host_memory) {
270       right_packet = Eigen::internal::pset1<TinPacket>(*right);
271     }
272   }
273 
274   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
275     return Binary::operator()(left, *right);
276   }
277 
278   template <typename Packet,
279             typename std::enable_if<!is_scalar_in_host_memory ||
280                                         !std::is_same<TinPacket, Packet>::value,
281                                     int>::type = 0>
282   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
283     const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
284     return Binary::packetOp(left_packet, right_packet);
285   }
286 
287   template <typename Packet,
288             typename std::enable_if<is_scalar_in_host_memory &&
289                                         std::is_same<TinPacket, Packet>::value,
290                                     int>::type = 0>
291   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
292     return Binary::packetOp(left_packet, right_packet);
293   }
294 };
295 
296 template <typename Tout, typename Tin, typename Binary,
297           bool is_scalar_in_host_memory>
298 struct functor_traits<
299     scalar_right<Tout, Tin, Binary, is_scalar_in_host_memory>> {
300   enum {
301     Cost = functor_traits<Binary>::Cost,
302     PacketAccess = functor_traits<Binary>::PacketAccess,
303   };
304 };
305 
306 // similar to std::equal_to, but with the DEVICE_FUNC qualifier
307 template <class T>
308 struct equal_to : std::binary_function<T, T, bool> {
309   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
310                                                         const T& y) const {
311     return x == y;
312   }
313 };
314 
315 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
316 template <class T>
317 struct not_equal_to : std::binary_function<T, T, bool> {
318   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
319                                                         const T& y) const {
320     return x != y;
321   }
322 };
323 
324 // similar to std::greater, but with the DEVICE_FUNC qualifier
325 template <class T>
326 struct greater : std::binary_function<T, T, bool> {
327   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
328                                                         const T& y) const {
329     return x > y;
330   }
331 };
332 
333 // similar to std::less, but with the DEVICE_FUNC qualifier
334 template <class T>
335 struct less : std::binary_function<T, T, bool> {
336   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
337                                                         const T& y) const {
338     return x < y;
339   }
340 };
341 
342 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier
343 template <class T>
344 struct greater_equal : std::binary_function<T, T, bool> {
345   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
346                                                         const T& y) const {
347     return x >= y;
348   }
349 };
350 
351 // similar to std::less_equal, but with the DEVICE_FUNC qualifier
352 template <class T>
353 struct less_equal : std::binary_function<T, T, bool> {
354   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
355                                                         const T& y) const {
356     return x <= y;
357   }
358 };
359 
360 // Functor that enables squared difference functor.
361 template <typename Scalar>
362 struct scalar_squared_difference_op {
363   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
364   operator()(const Scalar& a, const Scalar& b) const {
365     const Scalar v = scalar_difference_op<Scalar>()(a, b);
366     return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v));
367   }
368   template <typename Packet>
369   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
370                                                         const Packet& b) const {
371     const Packet v = scalar_difference_op<Scalar>().packetOp(a, b);
372     return scalar_product_op<Scalar>().packetOp(
373         v, scalar_conjugate_op<Scalar>().packetOp(v));
374   }
375 };
376 
377 template <typename Scalar>
378 struct functor_traits<scalar_squared_difference_op<Scalar>> {
379   enum {
380     Cost = functor_traits<scalar_difference_op<Scalar>>::Cost +
381            functor_traits<scalar_conjugate_op<Scalar>>::Cost +
382            functor_traits<scalar_product_op<Scalar>>::Cost,
383     PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess &&
384                    functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess &&
385                    functor_traits<scalar_product_op<Scalar>>::PacketAccess
386   };
387 };
388 
389 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
390 template <typename T, typename Enable = void>
391 struct google_floor_div {
392   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
393                                                      const T& y) const {
394     if ((x < T(0)) != (y < T(0))) {
395 // HIP does not have the device version of the abs routine defined
396 // for all datatypes that T can resolve to
397 #if defined(__HIP_DEVICE_COMPILE__)
398       T abs_x = (x < T(0)) ? -x : x;
399       T abs_y = (y < T(0)) ? -y : y;
400 #else
401       T abs_x = std::abs(x);
402       T abs_y = std::abs(y);
403 #endif
404       return -(abs_x + abs_y - 1) / abs_y;
405     } else {
406       return x / y;
407     }
408   }
409   template <typename Packet>
410   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
411                                                         const Packet& y) const {
412     Packet zeros = pzero(x);
413     Packet x_mask = pcmp_lt(x, zeros);
414     Packet y_mask = pcmp_lt(y, zeros);
415     Packet x_div_y = pdiv(x, y);
416     Packet abs_x = pabs(x);
417     Packet abs_y = pabs(y);
418     Packet ones = pones(x);
419     Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y);
420     return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y);
421   }
422 };
423 
424 template <typename T>
425 struct google_floor_div<
426     T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
427   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
428                                                      const T& y) const {
429     return x / y;
430   }
431   template <typename Packet>
432   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
433                                                         const Packet& y) const {
434     return pdiv(x, y);
435   }
436 };
437 
438 template <typename Scalar>
439 struct functor_traits<google_floor_div<Scalar>> {
440   enum {
441     Cost = 2 * Eigen::internal::scalar_div_cost<
442                    Scalar, packet_traits<Scalar>::HasDiv>::value +
443            NumTraits<Scalar>::AddCost,
444     PacketAccess = packet_traits<Scalar>::HasDiv
445   };
446 };
447 
448 template <typename T, typename Enable = void>
449 struct google_floor_div_real {
450   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
451                                                      const T& y) const {
452     return Eigen::numext::floor(x / y);
453   }
454   template <typename Packet>
455   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
456                                                         const Packet& y) const {
457     return pfloor(pdiv(x, y));
458   }
459 };
460 
461 template <typename Scalar>
462 struct functor_traits<google_floor_div_real<Scalar>> {
463   enum {
464     Cost = 2 * Eigen::internal::scalar_div_cost<
465                    Scalar, packet_traits<Scalar>::HasDiv>::value +
466            2 * NumTraits<Scalar>::AddCost,
467     PacketAccess =
468         packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor
469   };
470 };
471 
472 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
473 template <typename T>
474 struct google_floor_fmod {
475   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
476                                                      const T& y) const {
477     // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
478     T trunc_mod = scalar_fmod_op<T>()(x, y);
479     return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y
480                                                                : trunc_mod;
481   }
482 };
483 
484 template <typename Scalar>
485 struct functor_traits<google_floor_fmod<Scalar>> {
486   enum {
487     Cost = functor_traits<Eigen::internal::scalar_fmod_op<Scalar>>::Cost +
488            NumTraits<Scalar>::AddCost,
489     PacketAccess = false
490   };
491 };
492 
493 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
494 template <typename T>
495 struct google_floor_mod {
496   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
497                                                      const T& y) const {
498     // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
499     T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(x, y);
500     return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y
501                                                                : trunc_mod;
502   }
503 };
504 
505 template <typename Scalar>
506 struct functor_traits<google_floor_mod<Scalar>> {
507   enum {
508     Cost = functor_traits<Eigen::internal::scalar_mod2_op<Scalar>>::Cost +
509            NumTraits<Scalar>::AddCost,
510     PacketAccess = false
511   };
512 };
513 
// Helper macros that push a GCC diagnostic state ignoring -Wfloat-equal and
// pop it again. They expand to nothing unless EIGEN_COMP_GNUC is set and the
// language standard is newer than C++98.
#if EIGEN_COMP_GNUC && (__cplusplus > 199711L)
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
523 
524 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger,
525           bool HasRint = packet_traits<Scalar>::HasRint>
526 struct scalar_round_half_to_even_op {
527   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
528   operator()(const Scalar& x) const {
529     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
530                         NUMERIC_TYPE_MUST_BE_REAL)
531 
532     const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5));
533     const Scalar fraction = round_val - x;
534     if (TF_PREDICT_FALSE(fraction == Scalar(.5))) {
535       return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5));
536     } else {
537       return round_val;
538     }
539   }
540 
541   template <typename Packet>
542   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
543     Packet half = pset1<Packet>(Scalar(0.5));
544     Packet round_val = pfloor(padd(x, half));
545     Packet fraction = psub(round_val, x);
546     Packet half_mask = pcmp_eq(fraction, half);
547     bool any_halves = predux_any(half_mask);
548     if (TF_PREDICT_FALSE(any_halves)) {
549       Packet two = pset1<Packet>(Scalar(2));
550       Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half)));
551       return pselect(half_mask, nearest_even, round_val);
552     } else {
553       return round_val;
554     }
555   }
556 };
557 
558 template <typename Scalar>
559 struct scalar_round_half_to_even_op<Scalar, true, false> {
560   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
561   operator()(const Scalar& x) const {
562     return x;
563   }
564   template <typename Packet>
565   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
566     return x;
567   }
568 };
569 
570 template <typename Scalar>
571 struct scalar_round_half_to_even_op<Scalar, false, true> {
572   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
573   operator()(const Scalar& x) const {
574     return Eigen::numext::rint(x);
575   }
576   template <typename Packet>
577   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
578     return print(x);
579   }
580 };
581 
582 template <typename Scalar>
583 struct functor_traits<scalar_round_half_to_even_op<Scalar>> {
584   enum {
585     Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0
586                                                : 4 * NumTraits<Scalar>::AddCost,
587     PacketAccess = packet_traits<Scalar>::HasFloor &&
588                    packet_traits<Scalar>::HasAdd &&
589                    packet_traits<Scalar>::HasMul,
590   };
591 };
592 
593 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger>
594 struct scalar_round_up_op {
595   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
596   operator()(const Scalar& x) const {
597     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
598                         NUMERIC_TYPE_MUST_BE_REAL)
599     return Eigen::numext::floor(x + Scalar(0.5));
600   }
601 
602   template <typename Packet>
603   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
604     return pfloor(padd(x, pset1<Packet>(0.5)));
605   }
606 };
607 
608 template <typename Scalar>
609 struct scalar_round_up_op<Scalar, true> {
610   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
611   operator()(const Scalar& x) const {
612     return x;
613   }
614 
615   template <typename Packet>
616   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
617     return x;
618   }
619 };
620 
621 template <typename Scalar, bool IsInteger>
622 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> {
623   enum {
624     Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost,
625     PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor
626   };
627 };
628 
629 #undef ENABLE_FLOAT_EQUALITY_WARNING
630 #undef DISABLE_FLOAT_EQUALITY_WARNING
631 
632 template <typename Scalar>
633 struct bitwise_xor_op {
634   EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op)
635   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
636   operator()(const Scalar& x, const Scalar& y) const {
637     return x ^ y;
638   }
639   template <typename Packet>
640   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
641                                                         const Packet& b) const {
642     return Eigen::internal::pxor(a, b);
643   }
644 };
645 
646 template <typename Scalar>
647 struct functor_traits<bitwise_xor_op<Scalar>> {
648   enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
649 };
650 
651 template <typename Scalar>
652 struct xlogy_op {
653   EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
654   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
655   operator()(const Scalar& x, const Scalar& y) const {
656     if (x == Scalar(0.)) {
657       return Scalar(0.);
658     }
659     return x * numext::log(y);
660   }
661   template <typename Packet>
662   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
663                                                         const Packet& y) const {
664     Packet zeros = pzero(x);
665     Packet mask = pcmp_eq(x, zeros);
666     scalar_log_op<Scalar> log_op;
667     Packet log_y = log_op.packetOp(y);
668     Packet x_log_y = pmul(x, log_y);
669     return pselect(mask, x, x_log_y);
670   }
671 };
672 
673 template <typename Scalar>
674 struct functor_traits<xlogy_op<Scalar>> {
675   enum {
676     Cost = functor_traits<scalar_log_op<Scalar>>::Cost +
677            Eigen::NumTraits<Scalar>::MulCost,
678     PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess
679   };
680 };
681 
682 template <typename Scalar>
683 struct xlog1py_op {
684   EIGEN_EMPTY_STRUCT_CTOR(xlog1py_op)
685   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
686   operator()(const Scalar& x, const Scalar& y) const {
687     if (x == Scalar(0.)) {
688       return Scalar(0.);
689     }
690     return x * numext::log1p(y);
691   }
692   template <typename Packet>
693   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
694                                                         const Packet& y) const {
695     Packet zeros = pzero(x);
696     Packet mask = pcmp_eq(x, zeros);
697     scalar_log1p_op<Scalar> log1p_op;
698     Packet log1p_y = log1p_op.packetOp(y);
699     Packet x_log1p_y = pmul(x, log1p_y);
700     return pselect(mask, x, x_log1p_y);
701   }
702 };
703 
704 template <typename Scalar>
705 struct functor_traits<xlog1py_op<Scalar>> {
706   enum {
707     Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost +
708            Eigen::NumTraits<Scalar>::MulCost,
709 #if TENSORFLOW_USE_ROCM
710     PacketAccess = false,
711 #else
712     PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess
713 #endif
714   };
715 };
716 
717 template <typename Scalar>
718 struct xdivy_op {
719   EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
720   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
721   operator()(const Scalar& x, const Scalar& y) const {
722     if (x == Scalar(0.)) {
723       return Scalar(0.);
724     }
725     return x / y;
726   }
727   template <typename Packet>
728   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
729                                                         const Packet& y) const {
730     Packet zeros = pzero(x);
731     Packet mask = pcmp_eq(x, zeros);
732     Packet x_div_y = pdiv(x, y);
733     return pselect(mask, x, x_div_y);
734   }
735 };
736 
737 template <typename Scalar>
738 struct functor_traits<xdivy_op<Scalar>> {
739   enum {
740     Cost =
741         Eigen::NumTraits<Scalar>::AddCost +
742         Eigen::internal::scalar_div_cost<Scalar,
743                                          packet_traits<Scalar>::HasDiv>::value,
744     PacketAccess = packet_traits<Scalar>::HasDiv
745   };
746 };
747 
748 template <typename T>
749 struct scalar_erfinv_op {
750   EIGEN_EMPTY_STRUCT_CTOR(scalar_erfinv_op)
751   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
752     constexpr T half = T(0.5);
753     T y = numext::ndtri(half * x + half);
754     constexpr T half_sqrt = T(M_SQRT1_2);
755     return y * half_sqrt;
756   }
757   template <typename Packet>
758   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
759     Packet half = pset1<Packet>(T(0.5));
760     Packet y = pndtri<Packet>(pmadd(half, x, half));
761     Packet half_sqrt = pset1<Packet>(T(M_SQRT1_2));
762     return pmul(y, half_sqrt);
763   }
764 };
765 
766 template <typename T>
767 struct functor_traits<scalar_erfinv_op<T>> {
768   enum {
769     Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost,
770     PacketAccess = packet_traits<T>::HasNdtri,
771   };
772 };
773 
774 }  // end namespace internal
775 }  // end namespace Eigen
776 
777 namespace tensorflow {
778 namespace functor {
779 
780 ////////////////////////////////////////////////////////////////////////////////
781 // Helpers
782 ////////////////////////////////////////////////////////////////////////////////
783 
784 // Base template for functors whose input scalar type is T and
785 // output scalar type is R.
786 template <typename T, typename F, typename R = T>
787 struct base {
788   // func defines operator() and its vectorized version packetOp().
789   typedef F func;
790 
791   // If true, the functor's corresponding binary op will instantiate
792   // specialized kernels to perform an optimized broadcast
793   // operation. Each functor for which this is enabled increases the
794   // code size, so by default this is disabled for binary functors and
795   // is enabled on a per-op basis as needed.
796   static constexpr bool use_bcast_optimization = false;
797 
798   // operator() has the signature:
799   //  out_type operator()(in_type in0, in_type in1 ...)
800   typedef R out_type;
801   typedef T in_type;
802 
803   // TensorFlow provides tensor-ized version of "func". Roughly
804   // speaking, the tensorflow operation has the signature:
805   //   tout_type op(tin_type in0)
806   //   tout_type op(tin_type in0, tin_type in1)
807   //   tout_type op(tin_type in0, in_type scalar)
808   typedef typename TTypes<out_type>::Flat tout_type;
809   typedef typename TTypes<in_type>::ConstFlat tin_type;
810   typedef typename TTypes<in_type>::ConstScalar tscalar_type;
811 
812   // Whether the functor can error out.  Currently applies only to integer
813   // div and mod.
814   static constexpr bool has_errors = false;
815 };
816 
// Type trait selecting which scalar types get the broadcast speed
// optimization for binary ops. For now only float and double opt in; every
// other type reports false.
template <typename T>
struct use_bcast_optimization {
  static constexpr bool value{false};
};

// float opts in.
template <>
struct use_bcast_optimization<float> {
  static constexpr bool value{true};
};

// double opts in.
template <>
struct use_bcast_optimization<double> {
  static constexpr bool value{true};
};
833 
834 ////////////////////////////////////////////////////////////////////////////////
835 // Unary functors
836 ////////////////////////////////////////////////////////////////////////////////
837 
838 // abs(x) = |x|
839 // neg(x) = - x
840 // inverse(x) = 1 / x
841 // square(x) = x^2
842 // sqrt(x) = x^(1/2)
843 // rsqrt(x) = x^(-1/2)
844 // exp(x) = e^x
845 // expm1(x) = e^x - 1
846 // log(x) = natural logarithm of x
847 // log1p(x) = natural logarithm of 1 + x
848 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
849 // sigmoid = 1 / (1 + exp(-x))  // a.k.a, logistic
850 //
851 // NOTE: We may eventually implement common functions used in NN
852 // here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
853 // For reference, see speech/lstm/eigen_functors.h.
854 
855 template <typename T>
856 struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
857                   typename Eigen::internal::scalar_abs_op<T>::result_type> {};
858 
859 template <typename T>
860 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};
861 
862 template <typename T>
863 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};
864 
865 template <typename T>
866 struct square : base<T, Eigen::internal::scalar_square_op<T>> {};
867 
868 template <typename T>
869 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};
870 
871 template <typename T>
872 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};
873 
874 template <typename T>
875 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};
876 
877 template <typename T>
878 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};
879 
880 template <typename T>
881 struct log : base<T, Eigen::internal::scalar_log_op<T>> {};
882 
883 template <typename T>
884 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};
885 
886 template <typename T>
887 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};
888 
889 template <typename T>
890 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};
891 
892 template <typename T>
893 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};
894 
895 template <typename T>
896 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};
897 
898 template <typename T>
899 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};
900 
901 template <typename T>
902 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};
903 
904 template <typename T>
905 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};
906 
907 template <typename T>
908 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};
909 
910 template <typename T>
911 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};
912 
913 template <typename T>
914 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};
915 
916 template <typename T>
917 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
918 
919 template <typename T>
920 struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {};
921 
922 template <typename T>
923 struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {};
924 
925 template <typename T>
926 struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};
927 
928 template <typename T>
929 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
930 
931 template <typename T>
932 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};
933 
934 template <typename T>
935 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};
936 
937 template <typename T>
938 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};
939 
940 template <typename T>
941 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
942 
943 template <typename T>
944 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
945 
946 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
947 };
948 
949 // Flip all bits. Named invert to be consistent with numpy.
950 template <typename T>
951 struct invert_op {
952   EIGEN_EMPTY_STRUCT_CTOR(invert_op)
953   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
954     return ~a;
955   }
956 };
957 
958 template <typename T>
959 struct invert : base<T, invert_op<T>> {};
960 
// NOTE: std::isinf, std::isnan, std::isfinite are plain functions.
// Therefore we need to wrap them in functors to be used with Eigen's
// type system.
964 template <typename T>
965 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};
966 
967 template <typename T>
968 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};
969 
970 template <typename T>
971 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};
972 
973 template <typename T>
974 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};
975 
976 template <typename T>
977 struct round : base<T, Eigen::internal::scalar_round_half_to_even_op<T>> {};
978 
979 template <typename T>
980 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};
981 
982 // Note: rint rounds half values to even, just like round_half_to_even_op.
983 template <typename T>
984 struct rint : base<T, Eigen::internal::scalar_rint_op<T>> {};
985 
986 ////////////////////////////////////////////////////////////////////////////////
987 // Binary functors
988 ////////////////////////////////////////////////////////////////////////////////
989 
990 // Binary functors:
991 //
992 // add(x, y) = x + y
993 // sub(x, y) = x - y
994 // mul(x, y) = x * y
995 // div(x, y) = x / y
996 // mod(x, y) = x % y         (int32 and int64 only)
997 // fmod(x, y) = fmod(x, y)   (float and double only)
998 // pow(x, y) = x ^ y
// maximum(x, y) = x > y ? x : y   (uses Eigen::PropagateNaN)
// minimum(x, y) = x < y ? x : y   (uses Eigen::PropagateNaN)
1001 // squared_difference(x, y) = conj(x - y) * (x - y)
1002 
1003 template <typename T>
1004 struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
1005   static constexpr bool use_bcast_optimization = true;
1006 };
1007 
1008 template <typename T>
1009 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
1010   static constexpr bool use_bcast_optimization = true;
1011 };
1012 
1013 template <typename T>
1014 struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
1015   static constexpr bool use_bcast_optimization = true;
1016 };
1017 
1018 template <typename T>
1019 struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {};
1020 
1021 template <typename T>
1022 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};
1023 
1024 template <typename T>
1025 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
1026                               T, Eigen::internal::scalar_quotient_op<T>>> {
1027   static constexpr bool has_errors = true;
1028 };
1029 
1030 template <typename T>
1031 struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
1032 
1033 template <typename T>
1034 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
1035 
1036 template <typename T>
1037 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};
1038 
1039 template <typename T>
1040 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
1041                               T, Eigen::internal::scalar_mod2_op<T>>> {
1042   static constexpr bool has_errors = true;
1043 };
1044 
1045 template <typename T>
1046 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};
1047 
1048 template <typename T>
1049 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
1050                                     T, Eigen::internal::google_floor_mod<T>>> {
1051   static constexpr bool has_errors = true;
1052 };
1053 
1054 template <typename T>
1055 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};
1056 
1057 template <typename T>
1058 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
1059                                     T, Eigen::internal::google_floor_div<T>>> {
1060   static constexpr bool has_errors = true;
1061 };
1062 
1063 template <typename T>
1064 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
1065 
1066 template <typename T>
1067 struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {};
1068 
1069 template <typename T>
1070 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
1071   static constexpr bool has_errors = true;
1072 };
1073 
1074 // Version of safe_pow for integers which returns 0 if RHS is negative and LHS
1075 // is not 1 or -1. For use on GPUs, where we cannot raise an error.
1076 template <typename T>
1077 struct safe_pow_ignore_error_op {
1078   static_assert(std::is_integral<T>::value, "Integer type expected");
1079   EIGEN_EMPTY_STRUCT_CTOR(safe_pow_ignore_error_op)
1080   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1081                                                      const T& y) const {
1082     if (TF_PREDICT_FALSE(y < 0)) {
1083       if (x == T(-1)) {
1084         T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2));
1085         return trunc_mod == T(-1) ? T(-1) : T(1);
1086       }
1087       return x == T(1) ? T(1) : T(0);
1088     }
1089     return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
1090   }
1091 };
1092 
1093 template <typename T>
1094 struct safe_pow_ignore_error : base<T, safe_pow_ignore_error_op<T>> {};
1095 
1096 template <typename T>
1097 struct maximum
1098     : base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {};
1099 
1100 template <typename T>
1101 struct minimum
1102     : base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {};
1103 
1104 template <typename T>
1105 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};
1106 
1107 template <typename T>
1108 struct random_gamma_grad
1109     : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};
1110 
1111 template <typename T>
1112 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};
1113 
1114 template <typename T>
1115 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};
1116 
1117 template <typename T>
1118 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
1119 
1120 template <typename Scalar>
1121 struct scalar_atan2_op {
1122   EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op)
1123   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
1124   operator()(const Scalar& y, const Scalar& x) const {
1125 #if TENSORFLOW_USE_ROCM
1126     return ::atan2(y, x);
1127 #else
1128     return std::atan2(y, x);
1129 #endif
1130   }
1131 };
1132 
1133 template <typename T>
1134 struct atan2 : base<T, scalar_atan2_op<T>> {};
1135 
1136 template <typename T>
1137 struct squared_difference
1138     : base<T, Eigen::internal::scalar_squared_difference_op<T>> {};
1139 
1140 template <typename T>
1141 struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
1142 
1143 template <typename T>
1144 struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
1145 
1146 template <typename T>
1147 struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {};
1148 
1149 template <typename T>
1150 struct less : base<T, Eigen::internal::less<T>, bool> {};
1151 
1152 template <typename T>
1153 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};
1154 
1155 template <typename T>
1156 struct greater : base<T, Eigen::internal::greater<T>, bool> {};
1157 
1158 template <typename T>
1159 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};
1160 
1161 template <typename T>
1162 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};
1163 
1164 template <typename T>
1165 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};
1166 
1167 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};
1168 
1169 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
1170 
1171 template <typename T>
1172 struct bitwise_and_op {
1173   EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op)
1174   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1175                                                      const T& y) const {
1176     return x & y;
1177   }
1178 };
1179 
1180 template <typename T>
1181 struct bitwise_or_op {
1182   EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op)
1183   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1184                                                      const T& y) const {
1185     return x | y;
1186   }
1187 };
1188 
1189 template <typename T>
1190 struct bitwise_and : base<T, bitwise_and_op<T>> {};
1191 
1192 template <typename T>
1193 struct bitwise_or : base<T, bitwise_or_op<T>> {};
1194 
1195 template <typename T>
1196 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};
1197 
1198 template <typename T>
1199 struct left_shift_op {
1200   EIGEN_EMPTY_STRUCT_CTOR(left_shift_op)
1201   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1202                                                      const T& y) const {
1203     // Avoids UB: don't shift by larger than the bitwidth of T, and
1204     // performs left shifts as unsigned shifts.
1205     T y_clamped = y;
1206     if (y_clamped < 0) {
1207       y_clamped = 0;
1208     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
1209       y_clamped = sizeof(T) * CHAR_BIT - 1;
1210     }
1211     using U = typename std::make_unsigned<T>::type;
1212     return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped));
1213   }
1214 };
1215 
1216 template <typename T>
1217 struct right_shift_op {
1218   EIGEN_EMPTY_STRUCT_CTOR(right_shift_op)
1219   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1220                                                      const T& y) const {
1221     // Avoids UB: don't shift by larger than the bitwidth of T.
1222     T y_clamped = y;
1223     if (y_clamped < 0) {
1224       y_clamped = 0;
1225     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
1226       y_clamped = sizeof(T) * CHAR_BIT - 1;
1227     }
1228     // Technically right shifts of signed integers are not necessarily
1229     // arithmetic shifts according to the C++ standard. However in practice most
1230     // implementations are arithmetic shifts. If this proves to be a problem in
1231     // practice, we may need to use an alternative implementation.
1232     return x >> y_clamped;
1233   }
1234 };
1235 
1236 template <typename T>
1237 struct left_shift : base<T, left_shift_op<T>> {};
1238 
1239 template <typename T>
1240 struct right_shift : base<T, right_shift_op<T>> {};
1241 
1242 template <typename T>
1243 struct make_complex_func {
1244   typedef std::complex<T> result_type;
1245   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real,
1246                                                                T imag) const {
1247     return std::complex<T>(real, imag);
1248   }
1249 };
1250 
1251 template <typename T>
1252 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};
1253 
1254 template <typename T>
1255 struct get_real
1256     : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};
1257 
1258 template <typename T>
1259 struct get_imag
1260     : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
1261 
1262 template <typename T>
1263 struct get_angle
1264     : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
1265 
1266 template <typename T>
1267 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
1268 
1269 ////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, compute the base functor on each
// coefficient of the input tensors, and put the results in the output
// tensor.
1273 ////////////////////////////////////////////////////////////////////////////////
// Device-level entry point for a unary functor. Only the call operator's
// declaration appears here; no definition is given in this section.
template <typename Device, typename Functor>
struct UnaryFunctor {
  // On device "d", computes out[i] = Functor(in[i]).
  void operator()(const Device& d,
                  typename Functor::tout_type out,
                  typename Functor::tin_type in);
};
1280 
1281 template <typename Device, typename Functor, int NDIMS,
1282           bool has_errors = Functor::has_errors>
1283 struct BinaryFunctor {
1284   // Computes on device "d": out[i] = Functor(in0[i], in1[i])
1285   void operator()(const Device& d, typename Functor::tout_type out,
1286                   typename Functor::tin_type in0,
1287                   typename Functor::tin_type in1, bool* error);
1288 
1289   // Computes on device "d": out[i] = Functor(scalar[0], in[i])
1290   void Left(const Device& d, typename Functor::tout_type out,
1291             typename Functor::tscalar_type scalar,
1292             typename Functor::tin_type in, bool* error);
1293 
1294   // Computes on device "d": out[i] = Functor(in[i], scalar[0])
1295   void Right(const Device& d, typename Functor::tout_type out,
1296              typename Functor::tin_type in,
1297              typename Functor::tscalar_type scalar, bool* error);
1298 
1299   // Computes on device "d":
1300   //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
1301   //
1302   // TODO(zhifengc): makes BCast a template member function on NDIMS
1303   // instead making BinaryFunctor templates on NDIMS.
1304   void BCast(const Device& d,
1305              typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
1306              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
1307              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
1308              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
1309              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
1310              bool* error);
1311 };
1312 
1313 template <typename Device, typename T>
1314 struct ApproximateEqual {
1315   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
1316                   typename TTypes<T>::ConstFlat y, T tolerance,
1317                   typename TTypes<bool>::Flat z);
1318 };
1319 
1320 template <int NDIMS>
1321 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
1322   for (size_t i = 0; i < a.size(); ++i) {
1323     if (a[i] != 1) return false;
1324   }
1325   return true;
1326 }
1327 
1328 template <typename Device, typename T>
1329 struct SelectFunctor {
1330   void operator()(const Device& d, typename TTypes<T>::Flat out,
1331                   typename TTypes<bool>::ConstFlat cond_flat,
1332                   typename TTypes<T>::ConstFlat then_flat,
1333                   typename TTypes<T>::ConstFlat else_flat);
1334 };
1335 
1336 template <typename Device, typename T>
1337 struct SelectScalarFunctor {
1338   void operator()(const Device& d, typename TTypes<T>::Flat out,
1339                   typename TTypes<bool>::ConstScalar cond,
1340                   typename TTypes<T>::ConstFlat then_flat,
1341                   typename TTypes<T>::ConstFlat else_flat);
1342 };
1343 
1344 template <typename Device, typename T>
1345 struct BatchSelectFunctor {
1346   void operator()(const Device& d,
1347                   typename TTypes<T>::Matrix output_flat_outer_dims,
1348                   TTypes<bool>::ConstVec cond_vec,
1349                   typename TTypes<T>::ConstMatrix then_flat_outer_dims,
1350                   typename TTypes<T>::ConstMatrix else_flat_outer_dims);
1351 };
1352 
1353 template <typename Device, typename T, int NDIMS>
1354 struct BCastSelectFunctor {
1355   void operator()(const Device& d,
1356                   typename TTypes<T, NDIMS>::Tensor output_tensor,
1357                   typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
1358                   typename TTypes<T, NDIMS>::ConstTensor then_tensor,
1359                   typename TTypes<T, NDIMS>::ConstTensor else_tensor,
1360                   typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
1361                   typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
1362                   typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
1363 };
1364 
1365 }  // end namespace functor
1366 }  // end namespace tensorflow
1367 
1368 #endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
1369