• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
18 
19 #define _USE_MATH_DEFINES
20 #include <cmath>
21 #include <functional>
22 #include <type_traits>
23 
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/numeric_types.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 
29 namespace Eigen {
30 namespace internal {
31 
#if GOOGLE_CUDA
// CUDA builds get explicit specializations of Eigen's scalar_arg_op for
// complex<float> and complex<double>. The argument (phase angle) of re + i*im
// is computed as atan2(im, re) via the C math functions.
template <>
struct scalar_arg_op<std::complex<float>> {
  using result_type = Eigen::NumTraits<std::complex<float>>::Real;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
      const std::complex<float>& z) const {
    const float im = z.imag();
    const float re = z.real();
    return ::atan2f(im, re);
  }
};

template <>
struct scalar_arg_op<std::complex<double>> {
  using result_type = Eigen::NumTraits<std::complex<double>>::Real;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()(
      const std::complex<double>& z) const {
    const double im = z.imag();
    const double re = z.real();
    return ::atan2(im, re);
  }
};
#endif
51 
52 #if EIGEN_HAS_CXX11_MATH == 0
53 template <typename T>
54 struct scalar_asinh_op {
55   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
56     return static_cast<T>(std::asinh(a));
57   }
58 };
59 template <typename T>
60 struct functor_traits<scalar_asinh_op<T>> {
61   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
62 };
63 
64 template <typename T>
65 struct scalar_acosh_op {
66   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
67     return static_cast<T>(std::acosh(a));
68   }
69 };
70 template <typename T>
71 struct functor_traits<scalar_acosh_op<T>> {
72   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
73 };
74 
75 template <typename T>
76 struct scalar_atanh_op {
77   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
78     return static_cast<T>(std::atanh(a));
79   }
80 };
81 template <typename T>
82 struct functor_traits<scalar_atanh_op<T>> {
83   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
84 };
85 #endif
86 
87 template <typename Scalar, typename Exponent>
88 struct safe_scalar_binary_pow_op {
89   static_assert(std::is_integral<Scalar>::value, "Integer type expected");
90   static_assert(std::is_integral<Exponent>::value &&
91                     std::is_signed<Exponent>::value,
92                 "Signed integer type expected");
93 
94   bool* const error;
95 
96   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
97       : error(error) {}
98 
99   EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
100                                              const Exponent& b) const {
101     const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
102     if (TF_PREDICT_TRUE(safe_b >= 0)) {
103       return numext::pow(a, safe_b);
104     } else {
105       *error = true;
106       return 0;
107     }
108   }
109 };
110 
111 template <typename Scalar, typename Exponent>
112 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
113   enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
114 };
115 
116 template <typename T, typename DivOrMod>
117 struct safe_div_or_mod_op {
118   static_assert(std::is_integral<T>::value, "Integer type expected");
119 
120   bool* const error;
121 
122   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
123       : error(error) {}
124 
125   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
126                                                      const T& b) const {
127     const T safe_b = tensorflow::internal::SubtleMustCopy(b);
128     if (TF_PREDICT_TRUE(safe_b != 0)) {
129       // Avoid FPE for INT_MIN/-1.
130       const T safe_a = tensorflow::internal::SubtleMustCopy(a);
131       if (TF_PREDICT_FALSE(std::is_signed<T>::value &&
132                            safe_a == std::numeric_limits<T>::min() &&
133                            safe_b == T(-1))) {
134         // Prefer to overflow 'a' instead of crashing.
135         return DivOrMod()(-safe_a, 1);
136       }
137       return DivOrMod()(safe_a, safe_b);
138     } else {
139       *error = true;
140       return 0;
141     }
142   }
143 };
144 
145 template <typename T, typename DivOrMod>
146 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
147   enum {
148     Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
149     PacketAccess = false,
150   };
151 };
152 
153 template <typename T, typename Binary>
154 struct no_nan_op {
155   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
156                                                      const T& b) const {
157     if (b != T(0)) {
158       return Binary()(a, b);
159     } else {
160       return T(0);
161     }
162   }
163   template <typename Packet>
164   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
165                                                         const Packet& b) const {
166     const Packet mask = pcmp_eq(b, pzero(b));
167     const Packet quotient = Binary().packetOp(a, b);
168     return pandnot(quotient, mask);
169   }
170 };
171 
172 template <typename T, bool IsComplex = Eigen::NumTraits<T>::IsComplex>
173 struct div_no_nan_op;
174 
175 template <typename T>
176 struct div_no_nan_op<T, /*IsComplex=*/false>
177     : public no_nan_op<T, scalar_quotient_op<T>> {
178 };
179 
180 template <typename T>
181 struct functor_traits<div_no_nan_op<T, /*IsComplex=*/false>> {
182   enum {
183     Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
184     PacketAccess = true,
185   };
186 };
187 
188 // Whether or not complex division produces a NaN depends on the underlying
189 // implementation. Some compilers (e.g. gcc) use a simple method that divides
190 // by |b|^2, which may underflow to 0 for b != 0.
191 template <typename T>
192 struct div_no_nan_op<T, /*IsComplex=*/true> {
193   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
194                                                      const T& b) const {
195     if (b == T(0)) {
196       return T(0);
197     } else {
198       // If the numerator is zero, then the result must be zero even if |b|^2
199       // underflows to zero.
200       const T numerator =
201           scalar_product_op<T>()(a, scalar_conjugate_op<T>()(b));
202       if (numerator == T(0)) {
203         return T(0);
204       }
205     }
206     return scalar_quotient_op<T>()(a, b);
207   }
208   template <typename Packet>
209   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
210                                                         const Packet& b) const {
211     const Packet numerator = pmul(a, pconj(b));
212     const Packet mask = por(pcmp_eq(b, pzero(a)), pcmp_eq(numerator, pzero(a)));
213     const Packet quotient = pdiv(a, b);
214     return pandnot(quotient, mask);
215   }
216 };
217 
218 template <typename T>
219 struct functor_traits<div_no_nan_op<T, /*IsComplex=*/true>> {
220   enum {
221     Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::MulCost,
222     PacketAccess = packet_traits<T>::HasMul && packet_traits<T>::HasDiv &&
223                    packet_traits<T>::HasConj,
224   };
225 };
226 
227 template <typename T>
228 struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> {
229 };
230 
231 template <typename T>
232 struct functor_traits<mul_no_nan_op<T>> {
233   enum {
234     Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost,
235     PacketAccess = true,
236   };
237 };
238 
239 // scalar_left and scalar_right are template helpers to partially
240 // apply a binary function.
241 //
242 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a
243 // unary functor g_x(y) = f(x, y), where x is provided via the
244 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
245 // f(x, y).
246 
247 template <typename Tout, typename Tin, typename Binary,
248           bool is_scalar_in_host_memory = false>
249 struct scalar_left : private Binary {
250   using result_type = Tout;
251   using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;
252 
253   const Tin* left;
254   TinPacket left_packet;  // initialized iff is_scalar_in_host_memory == true
255 
256   inline scalar_left(const scalar_left& other) = default;
257 
258   template <typename... Args>
259   EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
260       : Binary(args...), left(c) {
261     if (is_scalar_in_host_memory) {
262       left_packet = Eigen::internal::pset1<TinPacket>(*left);
263     }
264   }
265 
266   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
267     return Binary::operator()(*left, right);
268   }
269 
270   template <typename Packet,
271             typename std::enable_if<!is_scalar_in_host_memory ||
272                                         !std::is_same<TinPacket, Packet>::value,
273                                     int>::type = 0>
274   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
275     const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
276     return Binary::packetOp(left_packet, right_packet);
277   }
278 
279   template <typename Packet,
280             typename std::enable_if<is_scalar_in_host_memory &&
281                                         std::is_same<TinPacket, Packet>::value,
282                                     int>::type = 0>
283   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
284     return Binary::packetOp(left_packet, right_packet);
285   }
286 };
287 
288 template <typename Tout, typename Tin, typename Binary,
289           bool is_scalar_in_host_memory>
290 struct functor_traits<
291     scalar_left<Tout, Tin, Binary, is_scalar_in_host_memory>> {
292   enum {
293     Cost = functor_traits<Binary>::Cost,
294     PacketAccess = functor_traits<Binary>::PacketAccess,
295   };
296 };
297 
298 template <typename Tout, typename Tin, typename Binary,
299           bool is_scalar_in_host_memory = false>
300 struct scalar_right : private Binary {
301   using result_type = Tout;
302   using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;
303 
304   const Tin* right;
305   TinPacket right_packet;  // initialized iff is_scalar_in_host_memory == true
306 
307   inline scalar_right(const scalar_right& other) = default;
308 
309   template <typename... Args>
310   EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
311       : Binary(args...), right(c) {
312     if (is_scalar_in_host_memory) {
313       right_packet = Eigen::internal::pset1<TinPacket>(*right);
314     }
315   }
316 
317   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
318     return Binary::operator()(left, *right);
319   }
320 
321   template <typename Packet,
322             typename std::enable_if<!is_scalar_in_host_memory ||
323                                         !std::is_same<TinPacket, Packet>::value,
324                                     int>::type = 0>
325   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
326     const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
327     return Binary::packetOp(left_packet, right_packet);
328   }
329 
330   template <typename Packet,
331             typename std::enable_if<is_scalar_in_host_memory &&
332                                         std::is_same<TinPacket, Packet>::value,
333                                     int>::type = 0>
334   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
335     return Binary::packetOp(left_packet, right_packet);
336   }
337 };
338 
339 template <typename Tout, typename Tin, typename Binary,
340           bool is_scalar_in_host_memory>
341 struct functor_traits<
342     scalar_right<Tout, Tin, Binary, is_scalar_in_host_memory>> {
343   enum {
344     Cost = functor_traits<Binary>::Cost,
345     PacketAccess = functor_traits<Binary>::PacketAccess,
346   };
347 };
348 
349 // similar to std::equal_to, but with the DEVICE_FUNC qualifier
350 template <class T>
351 struct equal_to : std::function<bool(T, T)> {
352   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
353                                                         const T& y) const {
354     return x == y;
355   }
356 };
357 
358 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
359 template <class T>
360 struct not_equal_to : std::function<bool(T, T)> {
361   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
362                                                         const T& y) const {
363     return x != y;
364   }
365 };
366 
367 // similar to std::greater, but with the DEVICE_FUNC qualifier
368 template <class T>
369 struct greater : std::function<bool(T, T)> {
370   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
371                                                         const T& y) const {
372     return x > y;
373   }
374 };
375 
376 // similar to std::less, but with the DEVICE_FUNC qualifier
377 template <class T>
378 struct less : std::function<bool(T, T)> {
379   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
380                                                         const T& y) const {
381     return x < y;
382   }
383 };
384 
385 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier
386 template <class T>
387 struct greater_equal : std::function<bool(T, T)> {
388   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
389                                                         const T& y) const {
390     return x >= y;
391   }
392 };
393 
394 // similar to std::less_equal, but with the DEVICE_FUNC qualifier
395 template <class T>
396 struct less_equal : std::function<bool(T, T)> {
397   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
398                                                         const T& y) const {
399     return x <= y;
400   }
401 };
402 
403 // Functor that enables squared difference functor.
404 template <typename Scalar>
405 struct scalar_squared_difference_op {
406   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
407   operator()(const Scalar& a, const Scalar& b) const {
408     const Scalar v = scalar_difference_op<Scalar>()(a, b);
409     return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v));
410   }
411   template <typename Packet>
412   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
413                                                         const Packet& b) const {
414     const Packet v = scalar_difference_op<Scalar>().packetOp(a, b);
415     return scalar_product_op<Scalar>().packetOp(
416         v, scalar_conjugate_op<Scalar>().packetOp(v));
417   }
418 };
419 
420 template <typename Scalar>
421 struct functor_traits<scalar_squared_difference_op<Scalar>> {
422   enum {
423     Cost = functor_traits<scalar_difference_op<Scalar>>::Cost +
424            functor_traits<scalar_conjugate_op<Scalar>>::Cost +
425            functor_traits<scalar_product_op<Scalar>>::Cost,
426     PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess &&
427                    functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess &&
428                    functor_traits<scalar_product_op<Scalar>>::PacketAccess
429   };
430 };
431 
432 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
433 template <typename T, typename Enable = void>
434 struct google_floor_div {
435   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
436                                                      const T& y) const {
437     const T z = x / y;
438     // Subtract one if there is a remainder and if the inputs have opposite
439     // signs. This approach avoids unnecessary overflows.
440     return z * y != x && (x < T(0) != y < T(0)) ? z - T(1) : z;
441   }
442   template <typename Packet>
443   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
444                                                         const Packet& y) const {
445     Packet zeros = pzero(x);
446     Packet x_mask = pcmp_lt(x, zeros);
447     Packet y_mask = pcmp_lt(y, zeros);
448     Packet x_div_y = pdiv(x, y);
449     Packet x_div_y_times_y = pmul(x_div_y, y);
450     return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y,
451                    psub(x_div_y, pones(x)));
452   }
453 };
454 
455 template <typename T>
456 struct google_floor_div<
457     T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
458   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
459                                                      const T& y) const {
460     return x / y;
461   }
462   template <typename Packet>
463   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
464                                                         const Packet& y) const {
465     return pdiv(x, y);
466   }
467 };
468 
469 template <typename Scalar>
470 struct functor_traits<google_floor_div<Scalar>> {
471   enum {
472     Cost = 2 * Eigen::internal::scalar_div_cost<
473                    Scalar, packet_traits<Scalar>::HasDiv>::value +
474            NumTraits<Scalar>::AddCost,
475     PacketAccess = packet_traits<Scalar>::HasDiv
476   };
477 };
478 
479 template <typename T, typename Enable = void>
480 struct google_floor_div_real {
481   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
482                                                      const T& y) const {
483     return Eigen::numext::floor(x / y);
484   }
485   template <typename Packet>
486   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
487                                                         const Packet& y) const {
488     return pfloor(pdiv(x, y));
489   }
490 };
491 
492 template <typename Scalar>
493 struct functor_traits<google_floor_div_real<Scalar>> {
494   enum {
495     Cost = 2 * Eigen::internal::scalar_div_cost<
496                    Scalar, packet_traits<Scalar>::HasDiv>::value +
497            2 * NumTraits<Scalar>::AddCost,
498     PacketAccess =
499         packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor
500   };
501 };
502 
503 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
504 template <typename T>
505 struct google_floor_fmod {
506   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
507                                                      const T& y) const {
508     // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
509     T trunc_mod = scalar_fmod_op<T>()(x, y);
510     return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y
511                                                                : trunc_mod;
512   }
513 };
514 
515 template <typename Scalar>
516 struct functor_traits<google_floor_fmod<Scalar>> {
517   enum {
518     Cost = functor_traits<Eigen::internal::scalar_fmod_op<Scalar>>::Cost +
519            NumTraits<Scalar>::AddCost,
520     PacketAccess = false
521   };
522 };
523 
524 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
525 template <typename T>
526 struct google_floor_mod {
527   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
528                                                      const T& y) const {
529     // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
530     T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(x, y);
531     return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y
532                                                                : trunc_mod;
533   }
534 };
535 
536 template <typename Scalar>
537 struct functor_traits<google_floor_mod<Scalar>> {
538   enum {
539     Cost = functor_traits<Eigen::internal::scalar_mod2_op<Scalar>>::Cost +
540            NumTraits<Scalar>::AddCost,
541     PacketAccess = false
542   };
543 };
544 
// Helpers to locally silence GCC's -Wfloat-equal around deliberate exact
// floating-point comparisons. Without GCC (or before C++11) they expand to
// nothing. Both are #undef'd again after the rounding functors.
#if !(EIGEN_COMP_GNUC && __cplusplus > 199711L)
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#else
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#endif
554 
555 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger,
556           bool HasRint = packet_traits<Scalar>::HasRint>
557 struct scalar_round_half_to_even_op {
558   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
559   operator()(const Scalar& x) const {
560     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
561                         NUMERIC_TYPE_MUST_BE_REAL)
562 
563     const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5));
564     const Scalar fraction = round_val - x;
565     if (TF_PREDICT_FALSE(fraction == Scalar(.5))) {
566       return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5));
567     } else {
568       return round_val;
569     }
570   }
571 
572   template <typename Packet>
573   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
574     Packet half = pset1<Packet>(Scalar(0.5));
575     Packet round_val = pfloor(padd(x, half));
576     Packet fraction = psub(round_val, x);
577     Packet half_mask = pcmp_eq(fraction, half);
578     bool any_halves = predux_any(half_mask);
579     if (TF_PREDICT_FALSE(any_halves)) {
580       Packet two = pset1<Packet>(Scalar(2));
581       Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half)));
582       return pselect(half_mask, nearest_even, round_val);
583     } else {
584       return round_val;
585     }
586   }
587 };
588 
589 template <typename Scalar>
590 struct scalar_round_half_to_even_op<Scalar, true, false> {
591   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
592   operator()(const Scalar& x) const {
593     return x;
594   }
595   template <typename Packet>
596   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
597     return x;
598   }
599 };
600 
601 template <typename Scalar>
602 struct scalar_round_half_to_even_op<Scalar, false, true> {
603   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
604   operator()(const Scalar& x) const {
605     return Eigen::numext::rint(x);
606   }
607   template <typename Packet>
608   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
609     return print(x);
610   }
611 };
612 
613 template <typename Scalar>
614 struct functor_traits<scalar_round_half_to_even_op<Scalar>> {
615   enum {
616     Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0
617                                                : 4 * NumTraits<Scalar>::AddCost,
618     PacketAccess = packet_traits<Scalar>::HasFloor &&
619                    packet_traits<Scalar>::HasAdd &&
620                    packet_traits<Scalar>::HasMul,
621   };
622 };
623 
624 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger>
625 struct scalar_round_up_op {
626   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
627   operator()(const Scalar& x) const {
628     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
629                         NUMERIC_TYPE_MUST_BE_REAL)
630     return Eigen::numext::floor(x + Scalar(0.5));
631   }
632 
633   template <typename Packet>
634   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
635     return pfloor(padd(x, pset1<Packet>(0.5)));
636   }
637 };
638 
639 template <typename Scalar>
640 struct scalar_round_up_op<Scalar, true> {
641   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
642   operator()(const Scalar& x) const {
643     return x;
644   }
645 
646   template <typename Packet>
647   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
648     return x;
649   }
650 };
651 
652 template <typename Scalar, bool IsInteger>
653 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> {
654   enum {
655     Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost,
656     PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor
657   };
658 };
659 
// The float-equality warning helpers are only meant for this section of the
// header; remove them so they do not leak to files including it.
#undef DISABLE_FLOAT_EQUALITY_WARNING
#undef ENABLE_FLOAT_EQUALITY_WARNING
662 
663 template <typename Scalar>
664 struct bitwise_xor_op {
665   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
666   operator()(const Scalar& x, const Scalar& y) const {
667     return x ^ y;
668   }
669   template <typename Packet>
670   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
671                                                         const Packet& b) const {
672     return Eigen::internal::pxor(a, b);
673   }
674 };
675 
676 template <typename Scalar>
677 struct functor_traits<bitwise_xor_op<Scalar>> {
678   enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
679 };
680 
681 template <typename Scalar>
682 struct xlogy_op {
683   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
684   operator()(const Scalar& x, const Scalar& y) const {
685     if (x == Scalar(0.)) {
686       return Scalar(0.);
687     }
688     return x * numext::log(y);
689   }
690   template <typename Packet>
691   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
692                                                         const Packet& y) const {
693     Packet zeros = pzero(x);
694     Packet mask = pcmp_eq(x, zeros);
695     scalar_log_op<Scalar> log_op;
696     Packet log_y = log_op.packetOp(y);
697     Packet x_log_y = pmul(x, log_y);
698     return pselect(mask, x, x_log_y);
699   }
700 };
701 
702 template <typename Scalar>
703 struct functor_traits<xlogy_op<Scalar>> {
704   enum {
705     Cost = functor_traits<scalar_log_op<Scalar>>::Cost +
706            Eigen::NumTraits<Scalar>::MulCost,
707     PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess
708   };
709 };
710 
711 template <typename Scalar>
712 struct xlog1py_op {
713   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
714   operator()(const Scalar& x, const Scalar& y) const {
715     if (x == Scalar(0.)) {
716       return Scalar(0.);
717     }
718     return x * numext::log1p(y);
719   }
720   template <typename Packet>
721   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
722                                                         const Packet& y) const {
723     Packet zeros = pzero(x);
724     Packet mask = pcmp_eq(x, zeros);
725     scalar_log1p_op<Scalar> log1p_op;
726     Packet log1p_y = log1p_op.packetOp(y);
727     Packet x_log1p_y = pmul(x, log1p_y);
728     return pselect(mask, x, x_log1p_y);
729   }
730 };
731 
732 template <typename Scalar>
733 struct functor_traits<xlog1py_op<Scalar>> {
734   enum {
735     Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost +
736            Eigen::NumTraits<Scalar>::MulCost,
737 #if TENSORFLOW_USE_ROCM
738     PacketAccess = false,
739 #else
740     PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess
741 #endif
742   };
743 };
744 
745 template <typename Scalar>
746 struct xdivy_op {
747   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
748   operator()(const Scalar& x, const Scalar& y) const {
749     if (x == Scalar(0.)) {
750       return Scalar(0.);
751     }
752     return x / y;
753   }
754   template <typename Packet>
755   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
756                                                         const Packet& y) const {
757     Packet zeros = pzero(x);
758     Packet mask = pcmp_eq(x, zeros);
759     Packet x_div_y = pdiv(x, y);
760     return pselect(mask, x, x_div_y);
761   }
762 };
763 
764 template <typename Scalar>
765 struct functor_traits<xdivy_op<Scalar>> {
766   enum {
767     Cost =
768         Eigen::NumTraits<Scalar>::AddCost +
769         Eigen::internal::scalar_div_cost<Scalar,
770                                          packet_traits<Scalar>::HasDiv>::value,
771     PacketAccess = packet_traits<Scalar>::HasDiv
772   };
773 };
774 
775 template <typename T>
776 struct scalar_erfinv_op {
777   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
778     constexpr T half = T(0.5);
779     T y = numext::ndtri(half * x + half);
780     constexpr T half_sqrt = T(M_SQRT1_2);
781     return y * half_sqrt;
782   }
783   template <typename Packet>
784   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
785     Packet half = pset1<Packet>(T(0.5));
786     Packet y = pndtri<Packet>(pmadd(half, x, half));
787     Packet half_sqrt = pset1<Packet>(T(M_SQRT1_2));
788     return pmul(y, half_sqrt);
789   }
790 };
791 
792 template <typename T>
793 struct functor_traits<scalar_erfinv_op<T>> {
794   enum {
795     Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost,
796     PacketAccess = packet_traits<T>::HasNdtri,
797   };
798 };
799 
800 }  // end namespace internal
801 }  // end namespace Eigen
802 
803 namespace tensorflow {
804 namespace functor {
805 
806 ////////////////////////////////////////////////////////////////////////////////
807 // Helpers
808 ////////////////////////////////////////////////////////////////////////////////
809 
810 // Base template for functors whose input scalar type is T and
811 // output scalar type is R.
812 template <typename T, typename F, typename R = T>
813 struct base {
814   // func defines operator() and its vectorized version packetOp().
815   typedef F func;
816 
817   // If true, the functor's corresponding binary op will instantiate
818   // specialized kernels to perform an optimized broadcast
819   // operation. Each functor for which this is enabled increases the
820   // code size, so by default this is disabled for binary functors and
821   // is enabled on a per-op basis as needed.
822   static constexpr bool use_bcast_optimization = false;
823 
824   // operator() has the signature:
825   //  out_type operator()(in_type in0, in_type in1 ...)
826   typedef R out_type;
827   typedef T in_type;
828 
829   // TensorFlow provides tensor-ized version of "func". Roughly
830   // speaking, the tensorflow operation has the signature:
831   //   tout_type op(tin_type in0)
832   //   tout_type op(tin_type in0, tin_type in1)
833   //   tout_type op(tin_type in0, in_type scalar)
834   typedef typename TTypes<out_type>::Flat tout_type;
835   typedef typename TTypes<in_type>::ConstFlat tin_type;
836   typedef typename TTypes<in_type>::ConstScalar tscalar_type;
837 
838   // Whether the functor can error out.  Currently applies only to integer
839   // div and mod.
840   static constexpr bool has_errors = false;
841 };
842 
// Trait selecting the optimized broadcast path per element type. For now only
// float and double opt in; every other type reports false.
template <typename T>
struct use_bcast_optimization {
  static constexpr bool value{false};
};

// float opts in.
template <>
struct use_bcast_optimization<float> {
  static constexpr bool value{true};
};

// double opts in.
template <>
struct use_bcast_optimization<double> {
  static constexpr bool value{true};
};
859 
860 ////////////////////////////////////////////////////////////////////////////////
861 // Unary functors
862 ////////////////////////////////////////////////////////////////////////////////
863 
// abs(x) = |x|
// neg(x) = -x
// inverse(x) = 1 / x
// square(x) = x^2
// sqrt(x) = x^(1/2)
// rsqrt(x) = x^(-1/2)
// exp(x) = e^x
// expm1(x) = e^x - 1
// log(x) = natural logarithm of x
// log1p(x) = natural logarithm of 1 + x
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
// sigmoid(x) = 1 / (1 + exp(-x))  // a.k.a. logistic
//
// NOTE: We may eventually implement common functions used in NNs
// here, e.g. rectifier, softplus, derivatives of tanh, sigmoid, etc.
// For reference, see speech/lstm/eigen_functors.h.
880 
881 template <typename T>
882 struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
883                   typename Eigen::internal::scalar_abs_op<T>::result_type> {};
884 
885 template <typename T>
886 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};
887 
888 template <typename T>
889 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};
890 
891 template <typename T>
892 struct square : base<T, Eigen::internal::scalar_square_op<T>> {};
893 
894 template <typename T>
895 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};
896 
897 template <typename T>
898 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};
899 
900 template <typename T>
901 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};
902 
903 template <typename T>
904 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};
905 
906 template <typename T>
907 struct log : base<T, Eigen::internal::scalar_log_op<T>> {};
908 
909 template <typename T>
910 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};
911 
912 template <typename T>
913 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};
914 
915 template <typename T>
916 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};
917 
918 template <typename T>
919 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};
920 
921 template <typename T>
922 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};
923 
924 template <typename T>
925 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};
926 
927 template <typename T>
928 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};
929 
930 template <typename T>
931 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};
932 
933 template <typename T>
934 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};
935 
936 template <typename T>
937 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};
938 
939 template <typename T>
940 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};
941 
942 template <typename T>
943 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
944 
945 template <typename T>
946 struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {};
947 
948 template <typename T>
949 struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {};
950 
951 template <typename T>
952 struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};
953 
954 template <typename T>
955 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
956 
957 template <typename T>
958 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};
959 
960 template <typename T>
961 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};
962 
963 template <typename T>
964 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};
965 
966 template <typename T>
967 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
968 
969 template <typename T>
970 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
971 
972 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
973 };
974 
975 // Flip all bits. Named invert to be consistent with numpy.
976 template <typename T>
977 struct invert_op {
978   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
979     return ~a;
980   }
981 };
982 
983 template <typename T>
984 struct invert : base<T, invert_op<T>> {};
985 
986 // NOTE: std::isinf, std::isnan, std::isfinite are plain function.
987 // Therefore we need to wrap them in functors to be used with Eigen's
988 // type system.
989 template <typename T>
990 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};
991 
992 template <typename T>
993 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};
994 
995 template <typename T>
996 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};
997 
998 template <typename T>
999 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};
1000 
1001 template <typename T>
1002 struct round : base<T, Eigen::internal::scalar_round_half_to_even_op<T>> {};
1003 
1004 template <typename T>
1005 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};
1006 
1007 // Note: rint rounds half values to even, just like round_half_to_even_op.
1008 template <typename T>
1009 struct rint : base<T, Eigen::internal::scalar_rint_op<T>> {};
1010 
1011 ////////////////////////////////////////////////////////////////////////////////
1012 // Binary functors
1013 ////////////////////////////////////////////////////////////////////////////////
1014 
1015 // Binary functors:
1016 //
1017 // add(x, y) = x + y
1018 // sub(x, y) = x - y
1019 // mul(x, y) = x * y
1020 // div(x, y) = x / y
1021 // mod(x, y) = x % y         (int32 and int64 only)
1022 // fmod(x, y) = fmod(x, y)   (float and double only)
1023 // pow(x, y) = x ^ y
1024 // maximum(x, y) = x > y ? x : y
1025 // minimum(x, y) = x < y ? x : y
1026 // squared_difference(x, y) = conj(x - y) * (x - y)
1027 
1028 template <typename T>
1029 struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
1030   static constexpr bool use_bcast_optimization = true;
1031 };
1032 
1033 template <typename T>
1034 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
1035   static constexpr bool use_bcast_optimization = true;
1036 };
1037 
1038 template <typename T>
1039 struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
1040   static constexpr bool use_bcast_optimization = true;
1041 };
1042 
1043 template <typename T>
1044 struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {};
1045 
1046 template <typename T>
1047 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};
1048 
1049 template <typename T>
1050 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
1051                               T, Eigen::internal::scalar_quotient_op<T>>> {
1052   static constexpr bool has_errors = true;
1053 };
1054 
1055 template <typename T>
1056 struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
1057 
1058 template <typename T>
1059 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
1060 
1061 template <typename T>
1062 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};
1063 
1064 template <typename T>
1065 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
1066                               T, Eigen::internal::scalar_mod2_op<T>>> {
1067   static constexpr bool has_errors = true;
1068 };
1069 
1070 template <typename T>
1071 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};
1072 
1073 template <typename T>
1074 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
1075                                     T, Eigen::internal::google_floor_mod<T>>> {
1076   static constexpr bool has_errors = true;
1077 };
1078 
1079 template <typename T>
1080 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};
1081 
1082 template <typename T>
1083 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
1084                                     T, Eigen::internal::google_floor_div<T>>> {
1085   static constexpr bool has_errors = true;
1086 };
1087 
1088 template <typename T>
1089 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
1090 
1091 template <typename T>
1092 struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {};
1093 
1094 template <typename T>
1095 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
1096   static constexpr bool has_errors = true;
1097 };
1098 
1099 // Version of safe_pow for integers which returns 0 if RHS is negative and LHS
1100 // is not 1 or -1. For use on GPUs, where we cannot raise an error.
1101 template <typename T>
1102 struct safe_pow_ignore_error_op {
1103   static_assert(std::is_integral<T>::value, "Integer type expected");
1104   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1105                                                      const T& y) const {
1106     if (TF_PREDICT_FALSE(y < 0)) {
1107       if (x == T(-1)) {
1108         T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2));
1109         return trunc_mod == T(-1) ? T(-1) : T(1);
1110       }
1111       return x == T(1) ? T(1) : T(0);
1112     }
1113     return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
1114   }
1115 };
1116 
1117 template <typename T>
1118 struct safe_pow_ignore_error : base<T, safe_pow_ignore_error_op<T>> {};
1119 
1120 template <typename T>
1121 struct maximum
1122     : base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {};
1123 
1124 template <typename T>
1125 struct minimum
1126     : base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {};
1127 
1128 template <typename T>
1129 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};
1130 
1131 template <typename T>
1132 struct random_gamma_grad
1133     : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};
1134 
1135 template <typename T>
1136 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};
1137 
1138 template <typename T>
1139 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};
1140 
1141 template <typename T>
1142 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
1143 
1144 template <typename Scalar>
1145 struct scalar_atan2_op {
1146   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
1147   operator()(const Scalar& y, const Scalar& x) const {
1148 #if TENSORFLOW_USE_ROCM
1149     return static_cast<Scalar>(::atan2(y, x));
1150 #else
1151     return static_cast<Scalar>(std::atan2(y, x));
1152 #endif
1153   }
1154 };
1155 
1156 template <typename T>
1157 struct atan2 : base<T, scalar_atan2_op<T>> {};
1158 
1159 template <typename T>
1160 struct squared_difference
1161     : base<T, Eigen::internal::scalar_squared_difference_op<T>> {};
1162 
1163 template <typename T>
1164 struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
1165 
1166 template <typename T>
1167 struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
1168 
1169 template <typename T>
1170 struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {};
1171 
1172 template <typename T>
1173 struct less : base<T, Eigen::internal::less<T>, bool> {};
1174 
1175 template <typename T>
1176 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};
1177 
1178 template <typename T>
1179 struct greater : base<T, Eigen::internal::greater<T>, bool> {};
1180 
1181 template <typename T>
1182 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};
1183 
1184 template <typename T>
1185 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};
1186 
1187 template <typename T>
1188 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};
1189 
1190 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};
1191 
1192 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
1193 
1194 template <typename T>
1195 struct bitwise_and_op {
1196   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1197                                                      const T& y) const {
1198     return x & y;
1199   }
1200 };
1201 
1202 template <typename T>
1203 struct bitwise_or_op {
1204   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1205                                                      const T& y) const {
1206     return x | y;
1207   }
1208 };
1209 
1210 template <typename T>
1211 struct bitwise_and : base<T, bitwise_and_op<T>> {};
1212 
1213 template <typename T>
1214 struct bitwise_or : base<T, bitwise_or_op<T>> {};
1215 
1216 template <typename T>
1217 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};
1218 
1219 template <typename T>
1220 struct left_shift_op {
1221   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1222                                                      const T& y) const {
1223     // Avoids UB: don't shift by larger than the bitwidth of T, and
1224     // performs left shifts as unsigned shifts.
1225     T y_clamped = y;
1226     if (y_clamped < 0) {
1227       y_clamped = 0;
1228     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
1229       y_clamped = sizeof(T) * CHAR_BIT - 1;
1230     }
1231     using U = typename std::make_unsigned<T>::type;
1232     return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped));
1233   }
1234 };
1235 
1236 template <typename T>
1237 struct right_shift_op {
1238   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1239                                                      const T& y) const {
1240     // Avoids UB: don't shift by larger than the bitwidth of T.
1241     T y_clamped = y;
1242     if (y_clamped < 0) {
1243       y_clamped = 0;
1244     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
1245       y_clamped = sizeof(T) * CHAR_BIT - 1;
1246     }
1247     // Technically right shifts of signed integers are not necessarily
1248     // arithmetic shifts according to the C++ standard. However in practice most
1249     // implementations are arithmetic shifts. If this proves to be a problem in
1250     // practice, we may need to use an alternative implementation.
1251     return x >> y_clamped;
1252   }
1253 };
1254 
1255 template <typename T>
1256 struct left_shift : base<T, left_shift_op<T>> {};
1257 
1258 template <typename T>
1259 struct right_shift : base<T, right_shift_op<T>> {};
1260 
1261 template <typename T>
1262 struct make_complex_func {
1263   typedef std::complex<T> result_type;
1264   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real,
1265                                                                T imag) const {
1266     return std::complex<T>(real, imag);
1267   }
1268 };
1269 
1270 template <typename T>
1271 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};
1272 
1273 template <typename T>
1274 struct get_real
1275     : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};
1276 
1277 template <typename T>
1278 struct get_imag
1279     : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
1280 
1281 template <typename T>
1282 struct get_angle
1283     : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
1284 
1285 template <typename T>
1286 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
1287 
1288 ////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, apply the base functor to each
// coefficient of the input tensors, and write the results to the output
// tensor.
1292 ////////////////////////////////////////////////////////////////////////////////
// Applies Functor element-wise on device "d":
//   out[i] = Functor(in[i])
// Only declared here; the definitions (presumably per-device specializations)
// are provided elsewhere.
template <typename Device, typename Functor>
struct UnaryFunctor {
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in);
};

// Like UnaryFunctor, but additionally receives `val` of type Targ, which
// parameterizes the element-wise computation:
//   out[i] = Functor_val(in[i])
// NOTE(review): how `val` reaches the functor (e.g. as a constructor argument
// of Functor::func) is decided by the definitions elsewhere -- confirm there.
template <typename Device, typename Functor, typename Targ>
struct UnaryFunctorWithArg {
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in, Targ val);
};
1306 
1307 template <typename Device, typename Functor, int NDIMS,
1308           bool has_errors = Functor::has_errors>
1309 struct BinaryFunctor {
1310   // Computes on device "d": out[i] = Functor(in0[i], in1[i])
1311   void operator()(const Device& d, typename Functor::tout_type out,
1312                   typename Functor::tin_type in0,
1313                   typename Functor::tin_type in1, bool* error);
1314 
1315   // Computes on device "d": out[i] = Functor(scalar[0], in[i])
1316   void Left(const Device& d, typename Functor::tout_type out,
1317             typename Functor::tscalar_type scalar,
1318             typename Functor::tin_type in, bool* error);
1319 
1320   // Computes on device "d": out[i] = Functor(in[i], scalar[0])
1321   void Right(const Device& d, typename Functor::tout_type out,
1322              typename Functor::tin_type in,
1323              typename Functor::tscalar_type scalar, bool* error);
1324 
1325   // Computes on device "d":
1326   //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
1327   //
1328   // TODO(zhifengc): makes BCast a template member function on NDIMS
1329   // instead making BinaryFunctor templates on NDIMS.
1330   void BCast(const Device& d,
1331              typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
1332              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
1333              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
1334              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
1335              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
1336              bool* error);
1337 };
1338 
1339 template <typename Device, typename T>
1340 struct ApproximateEqual {
1341   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
1342                   typename TTypes<T>::ConstFlat y, T tolerance,
1343                   typename TTypes<bool>::Flat z);
1344 };
1345 
1346 template <int NDIMS>
1347 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
1348   for (size_t i = 0; i < a.size(); ++i) {
1349     if (a[i] != 1) return false;
1350   }
1351   return true;
1352 }
1353 
1354 template <typename Device, typename T>
1355 struct SelectFunctor {
1356   void operator()(const Device& d, typename TTypes<T>::Flat out,
1357                   typename TTypes<bool>::ConstFlat cond_flat,
1358                   typename TTypes<T>::ConstFlat then_flat,
1359                   typename TTypes<T>::ConstFlat else_flat);
1360 };
1361 
1362 template <typename Device, typename T>
1363 struct SelectScalarFunctor {
1364   void operator()(const Device& d, typename TTypes<T>::Flat out,
1365                   typename TTypes<bool>::ConstScalar cond,
1366                   typename TTypes<T>::ConstFlat then_flat,
1367                   typename TTypes<T>::ConstFlat else_flat);
1368 };
1369 
1370 template <typename Device, typename T>
1371 struct BatchSelectFunctor {
1372   void operator()(const Device& d,
1373                   typename TTypes<T>::Matrix output_flat_outer_dims,
1374                   TTypes<bool>::ConstVec cond_vec,
1375                   typename TTypes<T>::ConstMatrix then_flat_outer_dims,
1376                   typename TTypes<T>::ConstMatrix else_flat_outer_dims);
1377 };
1378 
1379 template <typename Device, typename T, int NDIMS>
1380 struct BCastSelectFunctor {
1381   void operator()(const Device& d,
1382                   typename TTypes<T, NDIMS>::Tensor output_tensor,
1383                   typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
1384                   typename TTypes<T, NDIMS>::ConstTensor then_tensor,
1385                   typename TTypes<T, NDIMS>::ConstTensor else_tensor,
1386                   typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
1387                   typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
1388                   typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
1389 };
1390 
1391 }  // end namespace functor
1392 }  // end namespace tensorflow
1393 
1394 #endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
1395