• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
18 
19 #include <cmath>
20 #include <functional>
21 #include <type_traits>
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/numeric_types.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 
29 namespace Eigen {
30 namespace internal {
31 
32 #if GOOGLE_CUDA
33 template <>
34 struct scalar_arg_op<std::complex<float>> {
35   EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op)
36   typedef typename Eigen::NumTraits<std::complex<float>>::Real result_type;
37   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const float operator()(
38       const std::complex<float>& a) const {
39     return ::atan2f(a.imag(), a.real());
40   }
41 };
42 
43 template <>
44 struct scalar_arg_op<std::complex<double>> {
45   EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op)
46   typedef typename Eigen::NumTraits<std::complex<double>>::Real result_type;
47   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const double operator()(
48       const std::complex<double>& a) const {
49     return ::atan2(a.imag(), a.real());
50   }
51 };
52 #endif
53 
54 #if EIGEN_HAS_CXX11_MATH == 0
55 template <typename T>
56 struct scalar_asinh_op {
57   EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
58   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
59     return std::asinh(a);
60   }
61 };
62 template <typename T>
63 struct functor_traits<scalar_asinh_op<T>> {
64   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
65 };
66 
67 template <typename T>
68 struct scalar_acosh_op {
69   EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
70   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
71     return std::acosh(a);
72   }
73 };
74 template <typename T>
75 struct functor_traits<scalar_acosh_op<T>> {
76   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
77 };
78 
79 template <typename T>
80 struct scalar_atanh_op {
81   EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
82   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
83     return std::atanh(a);
84   }
85 };
86 template <typename T>
87 struct functor_traits<scalar_atanh_op<T>> {
88   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
89 };
90 #endif
91 
92 template <typename Scalar, typename Exponent>
93 struct safe_scalar_binary_pow_op {
94   static_assert(std::is_integral<Scalar>::value, "Integer type expected");
95   static_assert(std::is_integral<Exponent>::value &&
96                     std::is_signed<Exponent>::value,
97                 "Signed integer type expected");
98 
99   bool* const error;
100 
101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
102       : error(error) {}
103 
104   EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
105                                              const Exponent& b) const {
106     const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
107     if (TF_PREDICT_TRUE(safe_b >= 0)) {
108       return numext::pow(a, safe_b);
109     } else {
110       *error = true;
111       return 0;
112     }
113   }
114 };
115 
116 template <typename Scalar, typename Exponent>
117 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
118   enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
119 };
120 
121 template <typename T, typename DivOrMod>
122 struct safe_div_or_mod_op {
123   static_assert(std::is_integral<T>::value, "Integer type expected");
124 
125   bool* const error;
126 
127   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
128       : error(error) {}
129 
130   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
131                                                            const T& b) const {
132     const T safe_b = tensorflow::internal::SubtleMustCopy(b);
133     if (TF_PREDICT_TRUE(safe_b != 0)) {
134       return DivOrMod()(a, safe_b);
135     } else {
136       *error = true;
137       return 0;
138     }
139   }
140 };
141 
142 template <typename T, typename DivOrMod>
143 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
144   enum {
145     Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
146     PacketAccess = false,
147   };
148 };
149 
150 template <typename T, typename Binary>
151 struct no_nan_op {
152   EIGEN_EMPTY_STRUCT_CTOR(no_nan_op)
153   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
154                                                            const T& b) const {
155     if (b != T(0)) {
156       return Binary()(a, b);
157     } else {
158       return T(0);
159     }
160   }
161   template <typename Packet>
162   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
163   packetOp(const Packet& a, const Packet& b) const {
164     const Packet mask = pcmp_eq(b, pzero(b));
165     const Packet quotient = Binary().packetOp(a, b);
166     return pandnot(quotient, mask);
167   }
168 };
169 
170 template <typename T>
171 struct div_no_nan_op : public no_nan_op<T, scalar_quotient_op<T>> {
172   EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
173 };
174 
175 template <typename T>
176 struct functor_traits<div_no_nan_op<T>> {
177   enum {
178     Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
179     PacketAccess = true,
180   };
181 };
182 
183 template <typename T>
184 struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> {
185   EIGEN_EMPTY_STRUCT_CTOR(mul_no_nan_op)
186 };
187 
188 template <typename T>
189 struct functor_traits<mul_no_nan_op<T>> {
190   enum {
191     Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost,
192     PacketAccess = true,
193   };
194 };
195 
196 // scalar_left and scalar_right are template helpers to partially
197 // apply a binary function.
198 //
199 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a
200 // unary functor g_x(y) = f(x, y), where x is provided via the
201 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
202 // f(x, y).
203 
204 template <typename Tout, typename Tin, typename Binary>
205 struct scalar_left : private Binary {
206   typedef Tout result_type;
207   const Tin* left;
208 
209   EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
210 
211   template <typename... Args>
212   EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
213       : Binary(args...), left(c) {}
214 
215   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
216     return Binary::operator()(*left, right);
217   }
218 
219   template <typename Packet>
220   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
221     const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
222     return Binary::packetOp(left_packet, right_packet);
223   }
224 };
225 
226 template <typename Tout, typename Tin, typename Binary>
227 struct functor_traits<scalar_left<Tout, Tin, Binary>> {
228   enum {
229     Cost = functor_traits<Binary>::Cost,
230     PacketAccess = functor_traits<Binary>::PacketAccess,
231   };
232 };
233 
234 template <typename Tout, typename Tin, typename Binary>
235 struct scalar_right : private Binary {
236   typedef Tout result_type;
237   const Tin* right;
238 
239   EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
240 
241   template <typename... Args>
242   EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
243       : Binary(args...), right(c) {}
244 
245   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
246     return Binary::operator()(left, *right);
247   }
248 
249   template <typename Packet>
250   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
251     const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
252     return Binary::packetOp(left_packet, right_packet);
253   }
254 };
255 
256 template <typename Tout, typename Tin, typename Binary>
257 struct functor_traits<scalar_right<Tout, Tin, Binary>> {
258   enum {
259     Cost = functor_traits<Binary>::Cost,
260     PacketAccess = functor_traits<Binary>::PacketAccess,
261   };
262 };
263 
264 // similar to std::equal_to, but with the DEVICE_FUNC qualifier
265 template <class T>
266 struct equal_to : std::binary_function<T, T, bool> {
267   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
268                                                         const T& y) const {
269     return x == y;
270   }
271 };
272 
273 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
274 template <class T>
275 struct not_equal_to : std::binary_function<T, T, bool> {
276   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
277                                                         const T& y) const {
278     return x != y;
279   }
280 };
281 
282 // similar to std::greater, but with the DEVICE_FUNC qualifier
283 template <class T>
284 struct greater : std::binary_function<T, T, bool> {
285   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
286                                                         const T& y) const {
287     return x > y;
288   }
289 };
290 
291 // similar to std::less, but with the DEVICE_FUNC qualifier
292 template <class T>
293 struct less : std::binary_function<T, T, bool> {
294   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
295                                                         const T& y) const {
296     return x < y;
297   }
298 };
299 
300 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier
301 template <class T>
302 struct greater_equal : std::binary_function<T, T, bool> {
303   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
304                                                         const T& y) const {
305     return x >= y;
306   }
307 };
308 
309 // similar to std::less_equal, but with the DEVICE_FUNC qualifier
310 template <class T>
311 struct less_equal : std::binary_function<T, T, bool> {
312   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
313                                                         const T& y) const {
314     return x <= y;
315   }
316 };
317 
318 // Functor that enables squared difference functor.
319 template <typename Scalar>
320 struct scalar_squared_difference_op {
321   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
322   operator()(const Scalar& a, const Scalar& b) const {
323     const Scalar v = scalar_difference_op<Scalar>()(a, b);
324     return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v));
325   }
326   template <typename Packet>
327   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
328   packetOp(const Packet& a, const Packet& b) const {
329     const Packet v = scalar_difference_op<Scalar>().packetOp(a, b);
330     return scalar_product_op<Scalar>().packetOp(
331         v, scalar_conjugate_op<Scalar>().packetOp(v));
332   }
333 };
334 
335 template <typename Scalar>
336 struct functor_traits<scalar_squared_difference_op<Scalar>> {
337   enum {
338     Cost = functor_traits<scalar_difference_op<Scalar>>::Cost +
339            functor_traits<scalar_conjugate_op<Scalar>>::Cost +
340            functor_traits<scalar_product_op<Scalar>>::Cost,
341     PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess &&
342                    functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess &&
343                    functor_traits<scalar_product_op<Scalar>>::PacketAccess
344   };
345 };
346 
347 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
348 template <typename T, typename Enable = void>
349 struct google_floor_div {
350   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
351                                                            const T& y) const {
352     if ((x < T(0)) != (y < T(0))) {
353       T abs_x = std::abs(x);
354       T abs_y = std::abs(y);
355       return -(abs_x + abs_y - 1) / abs_y;
356     } else {
357       return x / y;
358     }
359   }
360   template <typename Packet>
361   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
362   packetOp(const Packet& x, const Packet& y) const {
363     Packet zeros = pzero(x);
364     Packet x_mask = pcmp_lt(x, zeros);
365     Packet y_mask = pcmp_lt(y, zeros);
366     Packet x_div_y = pdiv(x, y);
367     Packet abs_x = pabs(x);
368     Packet abs_y = pabs(y);
369     Packet ones = pones(x);
370     Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y);
371     return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y);
372   }
373 };
374 
375 template <typename T>
376 struct google_floor_div<
377     T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
378   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
379                                                            const T& y) const {
380     return x / y;
381   }
382   template <typename Packet>
383   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
384   packetOp(const Packet& x, const Packet& y) const {
385     return pdiv(x, y);
386   }
387 };
388 
389 template <typename Scalar>
390 struct functor_traits<google_floor_div<Scalar>> {
391   enum {
392     Cost = 2 * Eigen::internal::scalar_div_cost<
393                    Scalar, packet_traits<Scalar>::HasDiv>::value +
394            NumTraits<Scalar>::AddCost,
395     PacketAccess = packet_traits<Scalar>::HasDiv
396   };
397 };
398 
399 template <typename T, typename Enable = void>
400 struct google_floor_div_real {
401   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
402                                                            const T& y) const {
403     return Eigen::numext::floor(x / y);
404   }
405   template <typename Packet>
406   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
407   packetOp(const Packet& x, const Packet& y) const {
408     return pfloor(pdiv(x, y));
409   }
410 };
411 
412 template <typename Scalar>
413 struct functor_traits<google_floor_div_real<Scalar>> {
414   enum {
415     Cost = 2 * Eigen::internal::scalar_div_cost<
416                    Scalar, packet_traits<Scalar>::HasDiv>::value +
417            2 * NumTraits<Scalar>::AddCost,
418     PacketAccess =
419         packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor
420   };
421 };
422 
423 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
424 template <typename T>
425 struct google_floor_fmod {
426   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
427                                                            const T& y) const {
428     // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
429     T trunc_mod = std::fmod(x, y);
430     return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
431   }
432 };
433 
434 template <typename Scalar>
435 struct functor_traits<google_floor_fmod<Scalar>> {
436   enum {
437     Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
438            2 * NumTraits<Scalar>::AddCost,
439     PacketAccess = false
440   };
441 };
442 
443 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
444 template <typename T>
445 struct google_floor_mod {
446   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
447                                                            const T& y) const {
448     // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
449     T trunc_mod = x % y;
450     return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y;
451   }
452 };
453 
454 template <typename Scalar>
455 struct functor_traits<google_floor_mod<Scalar>> {
456   enum {
457     Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
458            2 * NumTraits<Scalar>::AddCost,
459     PacketAccess = false
460   };
461 };
462 
// Macros that bracket code performing intentional exact floating-point
// comparisons. With a GNU-compatible compiler in post-C++98 mode they push
// the diagnostic state and ignore -Wfloat-equal, then pop it again; in every
// other configuration they expand to nothing.
#if !(EIGEN_COMP_GNUC && __cplusplus > 199711L)
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#else
#define DISABLE_FLOAT_EQUALITY_WARNING                                   \
  _Pragma("GCC diagnostic push")                                         \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#endif
472 
473 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger>
474 struct scalar_round_op_google {
475   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
476   operator()(const Scalar& x) const {
477     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
478                         NUMERIC_TYPE_MUST_BE_REAL)
479 
480     Scalar round_val = Eigen::numext::floor(x);
481     const Scalar fraction = x - round_val;
482     if (fraction > Scalar(.5)) {
483       round_val += Scalar(1.0);
484     } else if (fraction == Scalar(.5)) {
485       const Scalar nearest_even_int =
486           round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x);
487       bool is_odd = (nearest_even_int == Scalar(1));
488       if (is_odd) {
489         round_val += Scalar(1);
490       }
491     }
492     return round_val;
493   }
494 };
495 
496 template <typename Scalar>
497 struct scalar_round_op_google<Scalar, true> {
498   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
499   operator()(const Scalar& x) const {
500     return x;
501   }
502   template <typename Packet>
503   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
504   packetOp(const Packet& x) const {
505     return x;
506   }
507 };
508 
509 template <typename Scalar>
510 struct functor_traits<scalar_round_op_google<Scalar>> {
511   enum {
512     Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0
513                                                : 4 * NumTraits<Scalar>::AddCost,
514     PacketAccess = Eigen::NumTraits<Scalar>::IsInteger
515   };
516 };
517 
518 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger>
519 struct scalar_round_up_op {
520   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
521   operator()(const Scalar& x) const {
522     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
523                         NUMERIC_TYPE_MUST_BE_REAL)
524     return Eigen::numext::floor(x + Scalar(0.5));
525   }
526 
527   template <typename Packet>
528   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
529   packetOp(const Packet& x) const {
530     return pfloor(padd(x, pset1<Packet>(0.5)));
531   }
532 };
533 
534 template <typename Scalar>
535 struct scalar_round_up_op<Scalar, true> {
536   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
537   operator()(const Scalar& x) const {
538     return x;
539   }
540 
541   template <typename Packet>
542   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
543   packetOp(const Packet& x) const {
544     return x;
545   }
546 };
547 
548 template <typename Scalar, bool IsInteger>
549 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> {
550   enum {
551     Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost,
552     PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor
553   };
554 };
555 
556 #undef ENABLE_FLOAT_EQUALITY_WARNING
557 #undef DISABLE_FLOAT_EQUALITY_WARNING
558 
559 template <typename Scalar>
560 struct bitwise_xor_op {
561   EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op)
562   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
563   operator()(const Scalar& x, const Scalar& y) const {
564     return x ^ y;
565   }
566   template <typename Packet>
567   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
568   packetOp(const Packet& a, const Packet& b) const {
569     return Eigen::internal::pxor(a, b);
570   }
571 };
572 
573 template <typename Scalar>
574 struct functor_traits<bitwise_xor_op<Scalar>> {
575   enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
576 };
577 
578 template <typename Scalar>
579 struct xlogy_op {
580   EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
581   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
582   operator()(const Scalar& x, const Scalar& y) const {
583     if (x == Scalar(0.)) {
584       return Scalar(0.);
585     }
586     return x * numext::log(y);
587   }
588   template <typename Packet>
589   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
590   packetOp(const Packet& x, const Packet& y) const {
591     Packet zeros = pzero(x);
592     Packet mask = pcmp_eq(x, zeros);
593     scalar_log_op<Scalar> log_op;
594     Packet log_y = log_op.packetOp(y);
595     Packet x_log_y = pmul(x, log_y);
596     return pselect(mask, x, x_log_y);
597   }
598 };
599 
600 template <typename Scalar>
601 struct functor_traits<xlogy_op<Scalar>> {
602   enum {
603     Cost = functor_traits<scalar_log_op<Scalar>>::Cost +
604            Eigen::NumTraits<Scalar>::MulCost,
605     PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess
606   };
607 };
608 
609 template <typename Scalar>
610 struct xdivy_op {
611   EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
612   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
613   operator()(const Scalar& x, const Scalar& y) const {
614     if (x == Scalar(0.)) {
615       return Scalar(0.);
616     }
617     return x / y;
618   }
619   template <typename Packet>
620   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
621   packetOp(const Packet& x, const Packet& y) const {
622     Packet zeros = pzero(x);
623     Packet mask = pcmp_eq(x, zeros);
624     Packet x_div_y = pdiv(x, y);
625     return pselect(mask, x, x_div_y);
626   }
627 };
628 
629 template <typename Scalar>
630 struct functor_traits<xdivy_op<Scalar>> {
631   enum {
632     Cost =
633         Eigen::NumTraits<Scalar>::AddCost +
634         Eigen::internal::scalar_div_cost<Scalar,
635                                          packet_traits<Scalar>::HasDiv>::value,
636     PacketAccess = packet_traits<Scalar>::HasDiv
637   };
638 };
639 
640 }  // end namespace internal
641 }  // end namespace Eigen
642 
643 namespace tensorflow {
644 namespace functor {
645 
646 ////////////////////////////////////////////////////////////////////////////////
647 // Helpers
648 ////////////////////////////////////////////////////////////////////////////////
649 
650 // Base template for functors whose input scalar type is T and
651 // output scalar type is R.
652 template <typename T, typename F, typename R = T>
653 struct base {
654   // func defines operator() and its vectorized version packetOp().
655   typedef F func;
656 
657   // If true, the functor's corresponding binary op will instantiate
658   // specialized kernels to perform an optimized broadcast
659   // operation. Each functor for which this is enabled increases the
660   // code size, so by default this is disabled for binary functors and
661   // is enabled on a per-op basis as needed.
662   static const bool use_bcast_optimization = false;
663 
664   // operator() has the signature:
665   //  out_type operator()(in_type in0, in_type in1 ...)
666   typedef R out_type;
667   typedef T in_type;
668 
669   // TensorFlow provides tensor-ized version of "func". Roughly
670   // speaking, the tensorflow operation has the signature:
671   //   tout_type op(tin_type in0)
672   //   tout_type op(tin_type in0, tin_type in1)
673   //   tout_type op(tin_type in0, in_type scalar)
674   typedef typename TTypes<out_type>::Flat tout_type;
675   typedef typename TTypes<in_type>::ConstFlat tin_type;
676   typedef typename TTypes<in_type>::ConstScalar tscalar_type;
677 
678   // Whether the functor can error out.  Currently applies only to integer
679   // div and mod.
680   static const bool has_errors = false;
681 };
682 
683 // For now, we only apply certain speed optimization for
684 // float/double's broadcast binary op.
// Type-level switch: `value` is false for every T by default.
template <class T>
struct use_bcast_optimization {
  static constexpr bool value = false;
};

// Enabled for float.
template <>
struct use_bcast_optimization<float> {
  static constexpr bool value = true;
};

// Enabled for double.
template <>
struct use_bcast_optimization<double> {
  static constexpr bool value = true;
};
699 
700 ////////////////////////////////////////////////////////////////////////////////
701 // Unary functors
702 ////////////////////////////////////////////////////////////////////////////////
703 
704 // abs(x) = |x|
705 // neg(x) = - x
706 // inverse(x) = 1 / x
707 // square(x) = x^2
708 // sqrt(x) = x^(1/2)
709 // rsqrt(x) = x^(-1/2)
710 // exp(x) = e^x
711 // expm1(x) = e^x - 1
712 // log(x) = natural logarithm of x
713 // log1p(x) = natural logarithm of 1 + x
714 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
715 // sigmoid = 1 / (1 + exp(-x))  // a.k.a, logistic
716 //
717 // NOTE: We may eventually implement common functions used in NN
// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
719 // For reference, see speech/lstm/eigen_functors.h.
720 
721 template <typename T>
722 struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
723                   typename Eigen::internal::scalar_abs_op<T>::result_type> {};
724 
725 template <typename T>
726 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};
727 
728 template <typename T>
729 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};
730 
731 template <typename T>
732 struct square : base<T, Eigen::internal::scalar_square_op<T>> {};
733 
734 template <typename T>
735 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};
736 
737 template <typename T>
738 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};
739 
740 template <typename T>
741 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};
742 
743 template <typename T>
744 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};
745 
746 template <typename T>
747 struct log : base<T, Eigen::internal::scalar_log_op<T>> {};
748 
749 template <typename T>
750 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};
751 
752 template <typename T>
753 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};
754 
755 template <typename T>
756 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};
757 
758 template <typename T>
759 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};
760 
761 template <typename T>
762 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};
763 
764 template <typename T>
765 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};
766 
767 template <typename T>
768 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};
769 
770 template <typename T>
771 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};
772 
773 template <typename T>
774 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};
775 
776 template <typename T>
777 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};
778 
779 template <typename T>
780 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};
781 
782 template <typename T>
783 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
784 
785 template <typename T>
786 struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};
787 
788 template <typename T>
789 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
790 
791 template <typename T>
792 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};
793 
794 template <typename T>
795 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};
796 
797 template <typename T>
798 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};
799 
800 template <typename T>
801 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
802 
803 template <typename T>
804 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
805 
806 template <typename T>
807 struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {};
808 
809 template <typename T>
810 struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {};
811 
812 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
813 };
814 
815 // Flip all bits. Named invert to be consistent with numpy.
816 template <typename T>
817 struct invert_op {
818   EIGEN_EMPTY_STRUCT_CTOR(invert_op)
819   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
820     return ~a;
821   }
822 };
823 
824 template <typename T>
825 struct invert : base<T, invert_op<T>> {};
826 
// NOTE: std::isinf, std::isnan, std::isfinite are plain functions.
828 // Therefore we need to wrap them in functors to be used with Eigen's
829 // type system.
830 template <typename T>
831 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};
832 
833 template <typename T>
834 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};
835 
836 template <typename T>
837 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};
838 
839 template <typename T>
840 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};
841 
842 template <typename T>
843 struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {};
844 
845 template <typename T>
846 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};
847 
848 /** this should go in Eigen
849  * \brief Template functor to compute the round to int value of a scalar
850  */
851 template <typename Scalar>
852 struct scalar_rint_op {
853   EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
854   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
855   operator()(const Scalar& a) const {
856 #if defined(__CUDACC__)
857     return ::rint(a);
858 #elif defined(__ANDROID__)
859     return rint(a);
860 #else
861     return std::rint(a);
862 #endif
863   }
864 };
865 
866 template <typename T>
867 struct rint : base<T, scalar_rint_op<T>> {};
868 
869 ////////////////////////////////////////////////////////////////////////////////
870 // Binary functors
871 ////////////////////////////////////////////////////////////////////////////////
872 
873 // Binary functors:
874 //
875 // add(x, y) = x + y
876 // sub(x, y) = x - y
877 // mul(x, y) = x * y
878 // div(x, y) = x / y
879 // mod(x, y) = x % y         (int32 and int64 only)
880 // fmod(x, y) = fmod(x, y)   (float and double only)
881 // pow(x, y) = x ^ y
882 // maximum(x, y) = x > y ? x : y
883 // minimum(x, y) = x < y ? x : y
884 // squared_difference(x, y) = conj(x - y) * (x - y)
885 
886 template <typename T>
887 struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
888   static const bool use_bcast_optimization = true;
889 };
890 
891 template <typename T>
892 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
893   static const bool use_bcast_optimization = true;
894 };
895 
896 template <typename T>
897 struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
898   static const bool use_bcast_optimization = true;
899 };
900 
901 template <typename T>
902 struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {};
903 
904 template <typename T>
905 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};
906 
907 template <typename T>
908 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
909                               T, Eigen::internal::scalar_quotient_op<T>>> {
910   static const bool has_errors = true;
911 };
912 
913 template <typename T>
914 struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
915 
916 template <typename T>
917 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
918 
919 template <typename T>
920 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};
921 
922 template <typename T>
923 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
924                               T, Eigen::internal::scalar_mod2_op<T>>> {
925   static const bool has_errors = true;
926 };
927 
928 template <typename T>
929 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};
930 
931 template <typename T>
932 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
933                                     T, Eigen::internal::google_floor_mod<T>>> {
934   static const bool has_errors = true;
935 };
936 
937 template <typename T>
938 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};
939 
940 template <typename T>
941 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
942                                     T, Eigen::internal::google_floor_div<T>>> {
943   static const bool has_errors = true;
944 };
945 
946 template <typename T>
947 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
948 
949 template <typename T>
950 struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {};
951 
952 template <typename T>
953 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
954   static const bool has_errors = true;
955 };
956 
957 template <typename T>
958 struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};
959 
960 template <typename T>
961 struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {};
962 
963 template <typename T>
964 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};
965 
966 template <typename T>
967 struct random_gamma_grad
968     : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};
969 
970 template <typename T>
971 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};
972 
973 template <typename T>
974 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};
975 
976 template <typename T>
977 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
978 
979 template <typename Scalar>
980 struct scalar_atan2_op {
981   EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op)
982   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
983   operator()(const Scalar& y, const Scalar& x) const {
984 #if GOOGLE_CUDA
985     return ::atan2(y, x);
986 #else
987     return std::atan2(y, x);
988 #endif
989   }
990 };
991 
992 template <typename T>
993 struct atan2 : base<T, scalar_atan2_op<T>> {};
994 
995 template <typename T>
996 struct squared_difference
997     : base<T, Eigen::internal::scalar_squared_difference_op<T>> {};
998 
999 template <typename T>
1000 struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
1001 
1002 template <typename T>
1003 struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
1004 
1005 template <typename T>
1006 struct less : base<T, Eigen::internal::less<T>, bool> {};
1007 
1008 template <typename T>
1009 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};
1010 
1011 template <typename T>
1012 struct greater : base<T, Eigen::internal::greater<T>, bool> {};
1013 
1014 template <typename T>
1015 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};
1016 
1017 template <typename T>
1018 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};
1019 
1020 template <typename T>
1021 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};
1022 
1023 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};
1024 
1025 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
1026 
1027 template <typename T>
1028 struct bitwise_and_op {
1029   EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op)
1030   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
1031                                                            const T& y) const {
1032     return x & y;
1033   }
1034 };
1035 
1036 template <typename T>
1037 struct bitwise_or_op {
1038   EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op)
1039   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
1040                                                            const T& y) const {
1041     return x | y;
1042   }
1043 };
1044 
1045 template <typename T>
1046 struct bitwise_and : base<T, bitwise_and_op<T>> {};
1047 
1048 template <typename T>
1049 struct bitwise_or : base<T, bitwise_or_op<T>> {};
1050 
1051 template <typename T>
1052 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};
1053 
1054 template <typename T>
1055 struct left_shift_op {
1056   EIGEN_EMPTY_STRUCT_CTOR(left_shift_op)
1057   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
1058                                                            const T& y) const {
1059     // Avoids UB: don't shift by larger than the bitwidth of T, and
1060     // performs left shifts as unsigned shifts.
1061     T y_clamped = y;
1062     if (y_clamped < 0) {
1063       y_clamped = 0;
1064     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
1065       y_clamped = sizeof(T) * CHAR_BIT - 1;
1066     }
1067     using U = typename std::make_unsigned<T>::type;
1068     return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped));
1069   }
1070 };
1071 
1072 template <typename T>
1073 struct right_shift_op {
1074   EIGEN_EMPTY_STRUCT_CTOR(right_shift_op)
1075   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
1076                                                            const T& y) const {
1077     // Avoids UB: don't shift by larger than the bitwidth of T.
1078     T y_clamped = y;
1079     if (y_clamped < 0) {
1080       y_clamped = 0;
1081     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
1082       y_clamped = sizeof(T) * CHAR_BIT - 1;
1083     }
1084     // Technically right shifts of signed integers are not necessarily
1085     // arithmetic shifts according to the C++ standard. However in practice most
1086     // implementations are arithmetic shifts. If this proves to be a problem in
1087     // practice, we may need to use an alternative implementation.
1088     return x >> y_clamped;
1089   }
1090 };
1091 
1092 template <typename T>
1093 struct left_shift : base<T, left_shift_op<T>> {};
1094 
1095 template <typename T>
1096 struct right_shift : base<T, right_shift_op<T>> {};
1097 
1098 template <typename T>
1099 struct make_complex_func {
1100   typedef std::complex<T> result_type;
1101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real,
1102                                                                T imag) const {
1103     return std::complex<T>(real, imag);
1104   }
1105 };
1106 
1107 template <typename T>
1108 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};
1109 
1110 template <typename T>
1111 struct get_real
1112     : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};
1113 
1114 template <typename T>
1115 struct get_imag
1116     : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
1117 
1118 template <typename T>
1119 struct get_angle
1120     : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
1121 
1122 template <typename T>
1123 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
1124 
1125 ////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, compute the base functor on each
// coefficient of the input tensors, and put the results in the output
// tensor.
1129 ////////////////////////////////////////////////////////////////////////////////
// Declaration only: applies a unary Functor on a given Device. The
// definition is not in this section.
template <class Device, class Functor>
struct UnaryFunctor {
  // On device "d", evaluates out[i] = Functor(in[i]).
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in);
};
1136 
1137 template <typename Device, typename Functor, int NDIMS,
1138           bool has_errors = Functor::has_errors>
1139 struct BinaryFunctor {
1140   // Computes on device "d": out[i] = Functor(in0[i], in1[i])
1141   void operator()(const Device& d, typename Functor::tout_type out,
1142                   typename Functor::tin_type in0,
1143                   typename Functor::tin_type in1, bool* error);
1144 
1145   // Computes on device "d": out[i] = Functor(scalar[0], in[i])
1146   void Left(const Device& d, typename Functor::tout_type out,
1147             typename Functor::tscalar_type scalar,
1148             typename Functor::tin_type in, bool* error);
1149 
1150   // Computes on device "d": out[i] = Functor(in[i], scalar[0])
1151   void Right(const Device& d, typename Functor::tout_type out,
1152              typename Functor::tin_type in,
1153              typename Functor::tscalar_type scalar, bool* error);
1154 
1155   // Computes on device "d":
1156   //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
1157   //
1158   // TODO(zhifengc): makes BCast a template member function on NDIMS
1159   // instead making BinaryFunctor templates on NDIMS.
1160   void BCast(const Device& d,
1161              typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
1162              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
1163              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
1164              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
1165              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
1166              bool* error);
1167 };
1168 
1169 template <typename Device, typename T>
1170 struct ApproximateEqual {
1171   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
1172                   typename TTypes<T>::ConstFlat y, T tolerance,
1173                   typename TTypes<bool>::Flat z);
1174 };
1175 
1176 template <int NDIMS>
1177 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
1178   for (size_t i = 0; i < a.size(); ++i) {
1179     if (a[i] != 1) return false;
1180   }
1181   return true;
1182 }
1183 
1184 template <typename Device, typename T>
1185 struct SelectFunctor {
1186   void operator()(const Device& d, typename TTypes<T>::Flat out,
1187                   typename TTypes<bool>::ConstFlat cond_flat,
1188                   typename TTypes<T>::ConstFlat then_flat,
1189                   typename TTypes<T>::ConstFlat else_flat);
1190 };
1191 
1192 template <typename Device, typename T>
1193 struct SelectScalarFunctor {
1194   void operator()(const Device& d, typename TTypes<T>::Flat out,
1195                   typename TTypes<bool>::ConstScalar cond,
1196                   typename TTypes<T>::ConstFlat then_flat,
1197                   typename TTypes<T>::ConstFlat else_flat);
1198 };
1199 
1200 template <typename Device, typename T>
1201 struct BatchSelectFunctor {
1202   void operator()(const Device& d,
1203                   typename TTypes<T>::Matrix output_flat_outer_dims,
1204                   TTypes<bool>::ConstVec cond_vec,
1205                   typename TTypes<T>::ConstMatrix then_flat_outer_dims,
1206                   typename TTypes<T>::ConstMatrix else_flat_outer_dims);
1207 };
1208 
1209 }  // end namespace functor
1210 }  // end namespace tensorflow
1211 
1212 #endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
1213