1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 18 19 #define _USE_MATH_DEFINES 20 #include <cmath> 21 #include <functional> 22 #include <type_traits> 23 24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 25 #include "tensorflow/core/framework/bounds_check.h" 26 #include "tensorflow/core/framework/numeric_types.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 29 namespace Eigen { 30 namespace internal { 31 32 #if GOOGLE_CUDA 33 template <> 34 struct scalar_arg_op<std::complex<float>> { 35 typedef typename Eigen::NumTraits<std::complex<float>>::Real result_type; 36 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( 37 const std::complex<float>& a) const { 38 return ::atan2f(a.imag(), a.real()); 39 } 40 }; 41 42 template <> 43 struct scalar_arg_op<std::complex<double>> { 44 typedef typename Eigen::NumTraits<std::complex<double>>::Real result_type; 45 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()( 46 const std::complex<double>& a) const { 47 return ::atan2(a.imag(), a.real()); 48 } 49 }; 50 #endif 51 52 #if EIGEN_HAS_CXX11_MATH == 0 53 template <typename T> 54 struct scalar_asinh_op { 55 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 56 return static_cast<T>(std::asinh(a)); 57 } 58 }; 59 template <typename T> 60 struct functor_traits<scalar_asinh_op<T>> { 61 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 62 }; 63 64 template <typename T> 65 struct scalar_acosh_op { 66 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 67 return static_cast<T>(std::acosh(a)); 68 } 69 }; 70 template <typename T> 71 struct functor_traits<scalar_acosh_op<T>> { 72 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 73 }; 74 75 template <typename T> 76 struct scalar_atanh_op { 77 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 78 return static_cast<T>(std::atanh(a)); 79 } 80 }; 81 template <typename T> 82 struct functor_traits<scalar_atanh_op<T>> { 83 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 84 }; 85 #endif 86 87 template <typename Scalar, typename Exponent> 88 struct safe_scalar_binary_pow_op { 89 static_assert(std::is_integral<Scalar>::value, "Integer type expected"); 90 static_assert(std::is_integral<Exponent>::value && 91 std::is_signed<Exponent>::value, 92 "Signed integer type expected"); 93 94 bool* const error; 95 96 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) 97 : error(error) {} 98 99 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 100 const Exponent& b) const { 101 const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); 102 if (TF_PREDICT_TRUE(safe_b >= 0)) { 103 return numext::pow(a, safe_b); 104 } else { 105 *error = true; 106 return 0; 107 } 108 } 109 }; 110 111 template <typename Scalar, typename Exponent> 112 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { 113 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 114 }; 115 116 template <typename T, typename DivOrMod> 117 struct safe_div_or_mod_op { 118 static_assert(std::is_integral<T>::value, "Integer type expected"); 119 120 bool* const error; 121 122 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) 123 : error(error) {} 124 125 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, 126 const T& b) const { 127 const T safe_b = tensorflow::internal::SubtleMustCopy(b); 128 if (TF_PREDICT_TRUE(safe_b != 0)) { 129 // Avoid FPE for INT_MIN/-1. 130 const T safe_a = tensorflow::internal::SubtleMustCopy(a); 131 if (TF_PREDICT_FALSE(std::is_signed<T>::value && 132 safe_a == std::numeric_limits<T>::min() && 133 safe_b == T(-1))) { 134 // Prefer to overflow 'a' instead of crashing. 135 return DivOrMod()(-safe_a, 1); 136 } 137 return DivOrMod()(safe_a, safe_b); 138 } else { 139 *error = true; 140 return 0; 141 } 142 } 143 }; 144 145 template <typename T, typename DivOrMod> 146 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> { 147 enum { 148 Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost, 149 PacketAccess = false, 150 }; 151 }; 152 153 template <typename T, typename Binary> 154 struct no_nan_op { 155 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, 156 const T& b) const { 157 if (b != T(0)) { 158 return Binary()(a, b); 159 } else { 160 return T(0); 161 } 162 } 163 template <typename Packet> 164 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 165 const Packet& b) const { 166 const Packet mask = pcmp_eq(b, pzero(b)); 167 const Packet quotient = Binary().packetOp(a, b); 168 return pandnot(quotient, mask); 169 } 170 }; 171 172 template <typename T, bool IsComplex = Eigen::NumTraits<T>::IsComplex> 173 struct div_no_nan_op; 174 175 template <typename T> 176 struct div_no_nan_op<T, /*IsComplex=*/false> 177 : public no_nan_op<T, scalar_quotient_op<T>> { 178 }; 179 180 template <typename T> 181 struct functor_traits<div_no_nan_op<T, /*IsComplex=*/false>> { 182 enum { 183 Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost, 184 PacketAccess = true, 185 }; 186 }; 187 188 // Whether or not complex division produces a NaN depends on the underlying 189 // implementation. Some compilers (e.g. gcc) use a simple method that divides 190 // by |b|^2, which may underflow to 0 for b != 0. 191 template <typename T> 192 struct div_no_nan_op<T, /*IsComplex=*/true> { 193 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, 194 const T& b) const { 195 if (b == T(0)) { 196 return T(0); 197 } else { 198 // If the numerator is zero, then the result must be zero even if |b|^2 199 // underflows to zero. 200 const T numerator = 201 scalar_product_op<T>()(a, scalar_conjugate_op<T>()(b)); 202 if (numerator == T(0)) { 203 return T(0); 204 } 205 } 206 return scalar_quotient_op<T>()(a, b); 207 } 208 template <typename Packet> 209 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 210 const Packet& b) const { 211 const Packet numerator = pmul(a, pconj(b)); 212 const Packet mask = por(pcmp_eq(b, pzero(a)), pcmp_eq(numerator, pzero(a))); 213 const Packet quotient = pdiv(a, b); 214 return pandnot(quotient, mask); 215 } 216 }; 217 218 template <typename T> 219 struct functor_traits<div_no_nan_op<T, /*IsComplex=*/true>> { 220 enum { 221 Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::MulCost, 222 PacketAccess = packet_traits<T>::HasMul && packet_traits<T>::HasDiv && 223 packet_traits<T>::HasConj, 224 }; 225 }; 226 227 template <typename T> 228 struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> { 229 }; 230 231 template <typename T> 232 struct functor_traits<mul_no_nan_op<T>> { 233 enum { 234 Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost, 235 PacketAccess = true, 236 }; 237 }; 238 239 // scalar_left and scalar_right are template helpers to partially 240 // apply a binary function. 241 // 242 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a 243 // unary functor g_x(y) = f(x, y), where x is provided via the 244 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) = 245 // f(x, y). 246 247 template <typename Tout, typename Tin, typename Binary, 248 bool is_scalar_in_host_memory = false> 249 struct scalar_left : private Binary { 250 using result_type = Tout; 251 using TinPacket = typename Eigen::internal::packet_traits<Tin>::type; 252 253 const Tin* left; 254 TinPacket left_packet; // initialized iff is_scalar_in_host_memory == true 255 256 inline scalar_left(const scalar_left& other) = default; 257 258 template <typename... Args> 259 EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args) 260 : Binary(args...), left(c) { 261 if (is_scalar_in_host_memory) { 262 left_packet = Eigen::internal::pset1<TinPacket>(*left); 263 } 264 } 265 266 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { 267 return Binary::operator()(*left, right); 268 } 269 270 template <typename Packet, 271 typename std::enable_if<!is_scalar_in_host_memory || 272 !std::is_same<TinPacket, Packet>::value, 273 int>::type = 0> 274 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 275 const Packet left_packet = Eigen::internal::pset1<Packet>(*left); 276 return Binary::packetOp(left_packet, right_packet); 277 } 278 279 template <typename Packet, 280 typename std::enable_if<is_scalar_in_host_memory && 281 std::is_same<TinPacket, Packet>::value, 282 int>::type = 0> 283 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 284 return Binary::packetOp(left_packet, right_packet); 285 } 286 }; 287 288 template <typename Tout, typename Tin, typename Binary, 289 bool is_scalar_in_host_memory> 290 struct functor_traits< 291 scalar_left<Tout, Tin, Binary, is_scalar_in_host_memory>> { 292 enum { 293 Cost = functor_traits<Binary>::Cost, 294 PacketAccess = functor_traits<Binary>::PacketAccess, 295 }; 296 }; 297 298 template <typename Tout, typename Tin, typename Binary, 299 bool is_scalar_in_host_memory = false> 300 struct scalar_right : private Binary { 301 using result_type = Tout; 302 using TinPacket = typename Eigen::internal::packet_traits<Tin>::type; 303 304 const Tin* right; 305 TinPacket right_packet; // initialized iff is_scalar_in_host_memory == true 306 307 inline scalar_right(const scalar_right& other) = default; 308 309 template <typename... Args> 310 EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args) 311 : Binary(args...), right(c) { 312 if (is_scalar_in_host_memory) { 313 right_packet = Eigen::internal::pset1<TinPacket>(*right); 314 } 315 } 316 317 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { 318 return Binary::operator()(left, *right); 319 } 320 321 template <typename Packet, 322 typename std::enable_if<!is_scalar_in_host_memory || 323 !std::is_same<TinPacket, Packet>::value, 324 int>::type = 0> 325 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 326 const Packet right_packet = Eigen::internal::pset1<Packet>(*right); 327 return Binary::packetOp(left_packet, right_packet); 328 } 329 330 template <typename Packet, 331 typename std::enable_if<is_scalar_in_host_memory && 332 std::is_same<TinPacket, Packet>::value, 333 int>::type = 0> 334 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 335 return Binary::packetOp(left_packet, right_packet); 336 } 337 }; 338 339 template <typename Tout, typename Tin, typename Binary, 340 bool is_scalar_in_host_memory> 341 struct functor_traits< 342 scalar_right<Tout, Tin, Binary, is_scalar_in_host_memory>> { 343 enum { 344 Cost = functor_traits<Binary>::Cost, 345 PacketAccess = functor_traits<Binary>::PacketAccess, 346 }; 347 }; 348 349 // similar to std::equal_to, but with the DEVICE_FUNC qualifier 350 template <class T> 351 struct equal_to : std::function<bool(T, T)> { 352 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 353 const T& y) const { 354 return x == y; 355 } 356 }; 357 358 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier 359 template <class T> 360 struct not_equal_to : std::function<bool(T, T)> { 361 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 362 const T& y) const { 363 return x != y; 364 } 365 }; 366 367 // similar to std::greater, but with the DEVICE_FUNC qualifier 368 template <class T> 369 struct greater : std::function<bool(T, T)> { 370 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 371 const T& y) const { 372 return x > y; 373 } 374 }; 375 376 // similar to std::less, but with the DEVICE_FUNC qualifier 377 template <class T> 378 struct less : std::function<bool(T, T)> { 379 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 380 const T& y) const { 381 return x < y; 382 } 383 }; 384 385 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier 386 template <class T> 387 struct greater_equal : std::function<bool(T, T)> { 388 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 389 const T& y) const { 390 return x >= y; 391 } 392 }; 393 394 // similar to std::less_equal, but with the DEVICE_FUNC qualifier 395 template <class T> 396 struct less_equal : std::function<bool(T, T)> { 397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 398 const T& y) const { 399 return x <= y; 400 } 401 }; 402 403 // Functor that enables squared difference functor. 404 template <typename Scalar> 405 struct scalar_squared_difference_op { 406 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 407 operator()(const Scalar& a, const Scalar& b) const { 408 const Scalar v = scalar_difference_op<Scalar>()(a, b); 409 return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v)); 410 } 411 template <typename Packet> 412 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 413 const Packet& b) const { 414 const Packet v = scalar_difference_op<Scalar>().packetOp(a, b); 415 return scalar_product_op<Scalar>().packetOp( 416 v, scalar_conjugate_op<Scalar>().packetOp(v)); 417 } 418 }; 419 420 template <typename Scalar> 421 struct functor_traits<scalar_squared_difference_op<Scalar>> { 422 enum { 423 Cost = functor_traits<scalar_difference_op<Scalar>>::Cost + 424 functor_traits<scalar_conjugate_op<Scalar>>::Cost + 425 functor_traits<scalar_product_op<Scalar>>::Cost, 426 PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess && 427 functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess && 428 functor_traits<scalar_product_op<Scalar>>::PacketAccess 429 }; 430 }; 431 432 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 433 template <typename T, typename Enable = void> 434 struct google_floor_div { 435 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 436 const T& y) const { 437 const T z = x / y; 438 // Subtract one if there is a remainder and if the inputs have opposite 439 // signs. This approach avoids unnecessary overflows. 440 return z * y != x && (x < T(0) != y < T(0)) ? z - T(1) : z; 441 } 442 template <typename Packet> 443 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 444 const Packet& y) const { 445 Packet zeros = pzero(x); 446 Packet x_mask = pcmp_lt(x, zeros); 447 Packet y_mask = pcmp_lt(y, zeros); 448 Packet x_div_y = pdiv(x, y); 449 Packet x_div_y_times_y = pmul(x_div_y, y); 450 return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y, 451 psub(x_div_y, pones(x))); 452 } 453 }; 454 455 template <typename T> 456 struct google_floor_div< 457 T, typename std::enable_if<std::is_unsigned<T>::value>::type> { 458 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 459 const T& y) const { 460 return x / y; 461 } 462 template <typename Packet> 463 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 464 const Packet& y) const { 465 return pdiv(x, y); 466 } 467 }; 468 469 template <typename Scalar> 470 struct functor_traits<google_floor_div<Scalar>> { 471 enum { 472 Cost = 2 * Eigen::internal::scalar_div_cost< 473 Scalar, packet_traits<Scalar>::HasDiv>::value + 474 NumTraits<Scalar>::AddCost, 475 PacketAccess = packet_traits<Scalar>::HasDiv 476 }; 477 }; 478 479 template <typename T, typename Enable = void> 480 struct google_floor_div_real { 481 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 482 const T& y) const { 483 return Eigen::numext::floor(x / y); 484 } 485 template <typename Packet> 486 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 487 const Packet& y) const { 488 return pfloor(pdiv(x, y)); 489 } 490 }; 491 492 template <typename Scalar> 493 struct functor_traits<google_floor_div_real<Scalar>> { 494 enum { 495 Cost = 2 * Eigen::internal::scalar_div_cost< 496 Scalar, packet_traits<Scalar>::HasDiv>::value + 497 2 * NumTraits<Scalar>::AddCost, 498 PacketAccess = 499 packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor 500 }; 501 }; 502 503 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. 504 template <typename T> 505 struct google_floor_fmod { 506 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 507 const T& y) const { 508 // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); 509 T trunc_mod = scalar_fmod_op<T>()(x, y); 510 return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y 511 : trunc_mod; 512 } 513 }; 514 515 template <typename Scalar> 516 struct functor_traits<google_floor_fmod<Scalar>> { 517 enum { 518 Cost = functor_traits<Eigen::internal::scalar_fmod_op<Scalar>>::Cost + 519 NumTraits<Scalar>::AddCost, 520 PacketAccess = false 521 }; 522 }; 523 524 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. 525 template <typename T> 526 struct google_floor_mod { 527 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 528 const T& y) const { 529 // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); 530 T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(x, y); 531 return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y 532 : trunc_mod; 533 } 534 }; 535 536 template <typename Scalar> 537 struct functor_traits<google_floor_mod<Scalar>> { 538 enum { 539 Cost = functor_traits<Eigen::internal::scalar_mod2_op<Scalar>>::Cost + 540 NumTraits<Scalar>::AddCost, 541 PacketAccess = false 542 }; 543 }; 544 545 #if EIGEN_COMP_GNUC && __cplusplus > 199711L 546 #define DISABLE_FLOAT_EQUALITY_WARNING \ 547 _Pragma("GCC diagnostic push") \ 548 _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") 549 #define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") 550 #else 551 #define DISABLE_FLOAT_EQUALITY_WARNING 552 #define ENABLE_FLOAT_EQUALITY_WARNING 553 #endif 554 555 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger, 556 bool HasRint = packet_traits<Scalar>::HasRint> 557 struct scalar_round_half_to_even_op { 558 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 559 operator()(const Scalar& x) const { 560 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 561 NUMERIC_TYPE_MUST_BE_REAL) 562 563 const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5)); 564 const Scalar fraction = round_val - x; 565 if (TF_PREDICT_FALSE(fraction == Scalar(.5))) { 566 return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5)); 567 } else { 568 return round_val; 569 } 570 } 571 572 template <typename Packet> 573 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 574 Packet half = pset1<Packet>(Scalar(0.5)); 575 Packet round_val = pfloor(padd(x, half)); 576 Packet fraction = psub(round_val, x); 577 Packet half_mask = pcmp_eq(fraction, half); 578 bool any_halves = predux_any(half_mask); 579 if (TF_PREDICT_FALSE(any_halves)) { 580 Packet two = pset1<Packet>(Scalar(2)); 581 Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half))); 582 return pselect(half_mask, nearest_even, round_val); 583 } else { 584 return round_val; 585 } 586 } 587 }; 588 589 template <typename Scalar> 590 struct scalar_round_half_to_even_op<Scalar, true, false> { 591 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 592 operator()(const Scalar& x) const { 593 return x; 594 } 595 template <typename Packet> 596 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 597 return x; 598 } 599 }; 600 601 template <typename Scalar> 602 struct scalar_round_half_to_even_op<Scalar, false, true> { 603 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 604 operator()(const Scalar& x) const { 605 return Eigen::numext::rint(x); 606 } 607 template <typename Packet> 608 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 609 return print(x); 610 } 611 }; 612 613 template <typename Scalar> 614 struct functor_traits<scalar_round_half_to_even_op<Scalar>> { 615 enum { 616 Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0 617 : 4 * NumTraits<Scalar>::AddCost, 618 PacketAccess = packet_traits<Scalar>::HasFloor && 619 packet_traits<Scalar>::HasAdd && 620 packet_traits<Scalar>::HasMul, 621 }; 622 }; 623 624 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger> 625 struct scalar_round_up_op { 626 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 627 operator()(const Scalar& x) const { 628 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 629 NUMERIC_TYPE_MUST_BE_REAL) 630 return Eigen::numext::floor(x + Scalar(0.5)); 631 } 632 633 template <typename Packet> 634 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 635 return pfloor(padd(x, pset1<Packet>(0.5))); 636 } 637 }; 638 639 template <typename Scalar> 640 struct scalar_round_up_op<Scalar, true> { 641 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 642 operator()(const Scalar& x) const { 643 return x; 644 } 645 646 template <typename Packet> 647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 648 return x; 649 } 650 }; 651 652 template <typename Scalar, bool IsInteger> 653 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> { 654 enum { 655 Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost, 656 PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor 657 }; 658 }; 659 660 #undef ENABLE_FLOAT_EQUALITY_WARNING 661 #undef DISABLE_FLOAT_EQUALITY_WARNING 662 663 template <typename Scalar> 664 struct bitwise_xor_op { 665 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 666 operator()(const Scalar& x, const Scalar& y) const { 667 return x ^ y; 668 } 669 template <typename Packet> 670 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 671 const Packet& b) const { 672 return Eigen::internal::pxor(a, b); 673 } 674 }; 675 676 template <typename Scalar> 677 struct functor_traits<bitwise_xor_op<Scalar>> { 678 enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; 679 }; 680 681 template <typename Scalar> 682 struct xlogy_op { 683 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 684 operator()(const Scalar& x, const Scalar& y) const { 685 if (x == Scalar(0.)) { 686 return Scalar(0.); 687 } 688 return x * numext::log(y); 689 } 690 template <typename Packet> 691 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 692 const Packet& y) const { 693 Packet zeros = pzero(x); 694 Packet mask = pcmp_eq(x, zeros); 695 scalar_log_op<Scalar> log_op; 696 Packet log_y = log_op.packetOp(y); 697 Packet x_log_y = pmul(x, log_y); 698 return pselect(mask, x, x_log_y); 699 } 700 }; 701 702 template <typename Scalar> 703 struct functor_traits<xlogy_op<Scalar>> { 704 enum { 705 Cost = functor_traits<scalar_log_op<Scalar>>::Cost + 706 Eigen::NumTraits<Scalar>::MulCost, 707 PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess 708 }; 709 }; 710 711 template <typename Scalar> 712 struct xlog1py_op { 713 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 714 operator()(const Scalar& x, const Scalar& y) const { 715 if (x == Scalar(0.)) { 716 return Scalar(0.); 717 } 718 return x * numext::log1p(y); 719 } 720 template <typename Packet> 721 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 722 const Packet& y) const { 723 Packet zeros = pzero(x); 724 Packet mask = pcmp_eq(x, zeros); 725 scalar_log1p_op<Scalar> log1p_op; 726 Packet log1p_y = log1p_op.packetOp(y); 727 Packet x_log1p_y = pmul(x, log1p_y); 728 return pselect(mask, x, x_log1p_y); 729 } 730 }; 731 732 template <typename Scalar> 733 struct functor_traits<xlog1py_op<Scalar>> { 734 enum { 735 Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost + 736 Eigen::NumTraits<Scalar>::MulCost, 737 #if TENSORFLOW_USE_ROCM 738 PacketAccess = false, 739 #else 740 PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess 741 #endif 742 }; 743 }; 744 745 template <typename Scalar> 746 struct xdivy_op { 747 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 748 operator()(const Scalar& x, const Scalar& y) const { 749 if (x == Scalar(0.)) { 750 return Scalar(0.); 751 } 752 return x / y; 753 } 754 template <typename Packet> 755 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 756 const Packet& y) const { 757 Packet zeros = pzero(x); 758 Packet mask = pcmp_eq(x, zeros); 759 Packet x_div_y = pdiv(x, y); 760 return pselect(mask, x, x_div_y); 761 } 762 }; 763 764 template <typename Scalar> 765 struct functor_traits<xdivy_op<Scalar>> { 766 enum { 767 Cost = 768 Eigen::NumTraits<Scalar>::AddCost + 769 Eigen::internal::scalar_div_cost<Scalar, 770 packet_traits<Scalar>::HasDiv>::value, 771 PacketAccess = packet_traits<Scalar>::HasDiv 772 }; 773 }; 774 775 template <typename T> 776 struct scalar_erfinv_op { 777 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { 778 constexpr T half = T(0.5); 779 T y = numext::ndtri(half * x + half); 780 constexpr T half_sqrt = T(M_SQRT1_2); 781 return y * half_sqrt; 782 } 783 template <typename Packet> 784 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 785 Packet half = pset1<Packet>(T(0.5)); 786 Packet y = pndtri<Packet>(pmadd(half, x, half)); 787 Packet half_sqrt = pset1<Packet>(T(M_SQRT1_2)); 788 return pmul(y, half_sqrt); 789 } 790 }; 791 792 template <typename T> 793 struct functor_traits<scalar_erfinv_op<T>> { 794 enum { 795 Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost, 796 PacketAccess = packet_traits<T>::HasNdtri, 797 }; 798 }; 799 800 } // end namespace internal 801 } // end namespace Eigen 802 803 namespace tensorflow { 804 namespace functor { 805 806 //////////////////////////////////////////////////////////////////////////////// 807 // Helpers 808 //////////////////////////////////////////////////////////////////////////////// 809 810 // Base template for functors whose input scalar type is T and 811 // output scalar type is R. 812 template <typename T, typename F, typename R = T> 813 struct base { 814 // func defines operator() and its vectorized version packetOp(). 815 typedef F func; 816 817 // If true, the functor's corresponding binary op will instantiate 818 // specialized kernels to perform an optimized broadcast 819 // operation. Each functor for which this is enabled increases the 820 // code size, so by default this is disabled for binary functors and 821 // is enabled on a per-op basis as needed. 822 static constexpr bool use_bcast_optimization = false; 823 824 // operator() has the signature: 825 // out_type operator()(in_type in0, in_type in1 ...) 826 typedef R out_type; 827 typedef T in_type; 828 829 // TensorFlow provides tensor-ized version of "func". Roughly 830 // speaking, the tensorflow operation has the signature: 831 // tout_type op(tin_type in0) 832 // tout_type op(tin_type in0, tin_type in1) 833 // tout_type op(tin_type in0, in_type scalar) 834 typedef typename TTypes<out_type>::Flat tout_type; 835 typedef typename TTypes<in_type>::ConstFlat tin_type; 836 typedef typename TTypes<in_type>::ConstScalar tscalar_type; 837 838 // Whether the functor can error out. Currently applies only to integer 839 // div and mod. 840 static constexpr bool has_errors = false; 841 }; 842 843 // For now, we only apply certain speed optimization for 844 // float/double's broadcast binary op. 845 template <typename T> 846 struct use_bcast_optimization { 847 static constexpr bool value = false; 848 }; 849 850 template <> 851 struct use_bcast_optimization<float> { 852 static constexpr bool value = true; 853 }; 854 855 template <> 856 struct use_bcast_optimization<double> { 857 static constexpr bool value = true; 858 }; 859 860 //////////////////////////////////////////////////////////////////////////////// 861 // Unary functors 862 //////////////////////////////////////////////////////////////////////////////// 863 864 // abs(x) = |x| 865 // neg(x) = - x 866 // inverse(x) = 1 / x 867 // square(x) = x^2 868 // sqrt(x) = x^(1/2) 869 // rsqrt(x) = x^(-1/2) 870 // exp(x) = e^x 871 // expm1(x) = e^x - 1 872 // log(x) = natural logarithm of x 873 // log1p(x) = natural logarithm of 1 + x 874 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) 875 // sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic 876 // 877 // NOTE: We may eventually implement common functions used in NN 878 // here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. 879 // For reference, see speech/lstm/eigen_functors.h. 880 881 template <typename T> 882 struct abs : base<T, Eigen::internal::scalar_abs_op<T>, 883 typename Eigen::internal::scalar_abs_op<T>::result_type> {}; 884 885 template <typename T> 886 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {}; 887 888 template <typename T> 889 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {}; 890 891 template <typename T> 892 struct square : base<T, Eigen::internal::scalar_square_op<T>> {}; 893 894 template <typename T> 895 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {}; 896 897 template <typename T> 898 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {}; 899 900 template <typename T> 901 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {}; 902 903 template <typename T> 904 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {}; 905 906 template <typename T> 907 struct log : base<T, Eigen::internal::scalar_log_op<T>> {}; 908 909 template <typename T> 910 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {}; 911 912 template <typename T> 913 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {}; 914 915 template <typename T> 916 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {}; 917 918 template <typename T> 919 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {}; 920 921 template <typename T> 922 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {}; 923 924 template <typename T> 925 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {}; 926 927 template <typename T> 928 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {}; 929 930 template <typename T> 931 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {}; 932 933 template <typename T> 934 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {}; 935 936 template <typename T> 937 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {}; 938 939 template <typename T> 940 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {}; 941 942 template <typename T> 943 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {}; 944 945 template <typename T> 946 struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {}; 947 948 template <typename T> 949 struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {}; 950 951 template <typename T> 952 struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {}; 953 954 template <typename T> 955 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {}; 956 957 template <typename T> 958 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {}; 959 960 template <typename T> 961 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {}; 962 963 template <typename T> 964 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {}; 965 966 template <typename T> 967 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {}; 968 969 template <typename T> 970 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {}; 971 972 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> { 973 }; 974 975 // Flip all bits. Named invert to be consistent with numpy. 976 template <typename T> 977 struct invert_op { 978 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 979 return ~a; 980 } 981 }; 982 983 template <typename T> 984 struct invert : base<T, invert_op<T>> {}; 985 986 // NOTE: std::isinf, std::isnan, std::isfinite are plain function. 987 // Therefore we need to wrap them in functors to be used with Eigen's 988 // type system. 989 template <typename T> 990 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {}; 991 992 template <typename T> 993 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {}; 994 995 template <typename T> 996 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {}; 997 998 template <typename T> 999 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {}; 1000 1001 template <typename T> 1002 struct round : base<T, Eigen::internal::scalar_round_half_to_even_op<T>> {}; 1003 1004 template <typename T> 1005 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {}; 1006 1007 // Note: rint rounds half values to even, just like round_half_to_even_op. 1008 template <typename T> 1009 struct rint : base<T, Eigen::internal::scalar_rint_op<T>> {}; 1010 1011 //////////////////////////////////////////////////////////////////////////////// 1012 // Binary functors 1013 //////////////////////////////////////////////////////////////////////////////// 1014 1015 // Binary functors: 1016 // 1017 // add(x, y) = x + y 1018 // sub(x, y) = x - y 1019 // mul(x, y) = x * y 1020 // div(x, y) = x / y 1021 // mod(x, y) = x % y (int32 and int64 only) 1022 // fmod(x, y) = fmod(x, y) (float and double only) 1023 // pow(x, y) = x ^ y 1024 // maximum(x, y) = x > y ? x : y 1025 // minimum(x, y) = x < y ? x : y 1026 // squared_difference(x, y) = conj(x - y) * (x - y) 1027 1028 template <typename T> 1029 struct add : base<T, Eigen::internal::scalar_sum_op<T>> { 1030 static constexpr bool use_bcast_optimization = true; 1031 }; 1032 1033 template <typename T> 1034 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> { 1035 static constexpr bool use_bcast_optimization = true; 1036 }; 1037 1038 template <typename T> 1039 struct mul : base<T, Eigen::internal::scalar_product_op<T>> { 1040 static constexpr bool use_bcast_optimization = true; 1041 }; 1042 1043 template <typename T> 1044 struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {}; 1045 1046 template <typename T> 1047 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {}; 1048 1049 template <typename T> 1050 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op< 1051 T, Eigen::internal::scalar_quotient_op<T>>> { 1052 static constexpr bool has_errors = true; 1053 }; 1054 1055 template <typename T> 1056 struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {}; 1057 1058 template <typename T> 1059 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {}; 1060 1061 template <typename T> 1062 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {}; 1063 1064 template <typename T> 1065 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op< 1066 T, Eigen::internal::scalar_mod2_op<T>>> { 1067 static constexpr bool has_errors = true; 1068 }; 1069 1070 template <typename T> 1071 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {}; 1072 1073 template <typename T> 1074 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op< 1075 T, Eigen::internal::google_floor_mod<T>>> { 1076 static constexpr bool has_errors = true; 1077 }; 1078 1079 template <typename T> 1080 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {}; 1081 1082 template <typename T> 1083 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op< 1084 T, Eigen::internal::google_floor_div<T>>> { 1085 static constexpr bool has_errors = true; 1086 }; 1087 1088 template <typename T> 1089 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {}; 1090 1091 template <typename T> 1092 struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {}; 1093 1094 template <typename T> 1095 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { 1096 static constexpr bool has_errors = true; 1097 }; 1098 1099 // Version of safe_pow for integers which returns 0 if RHS is negative and LHS 1100 // is not 1 or -1. For use on GPUs, where we cannot raise an error. 1101 template <typename T> 1102 struct safe_pow_ignore_error_op { 1103 static_assert(std::is_integral<T>::value, "Integer type expected"); 1104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1105 const T& y) const { 1106 if (TF_PREDICT_FALSE(y < 0)) { 1107 if (x == T(-1)) { 1108 T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2)); 1109 return trunc_mod == T(-1) ? T(-1) : T(1); 1110 } 1111 return x == T(1) ? T(1) : T(0); 1112 } 1113 return Eigen::internal::scalar_pow_op<T, T>{}(x, y); 1114 } 1115 }; 1116 1117 template <typename T> 1118 struct safe_pow_ignore_error : base<T, safe_pow_ignore_error_op<T>> {}; 1119 1120 template <typename T> 1121 struct maximum 1122 : base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {}; 1123 1124 template <typename T> 1125 struct minimum 1126 : base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {}; 1127 1128 template <typename T> 1129 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; 1130 1131 template <typename T> 1132 struct random_gamma_grad 1133 : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {}; 1134 1135 template <typename T> 1136 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; 1137 1138 template <typename T> 1139 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; 1140 1141 template <typename T> 1142 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; 1143 1144 template <typename Scalar> 1145 struct scalar_atan2_op { 1146 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 1147 operator()(const Scalar& y, const Scalar& x) const { 1148 #if TENSORFLOW_USE_ROCM 1149 return static_cast<Scalar>(::atan2(y, x)); 1150 #else 1151 return static_cast<Scalar>(std::atan2(y, x)); 1152 #endif 1153 } 1154 }; 1155 1156 template <typename T> 1157 struct atan2 : base<T, scalar_atan2_op<T>> {}; 1158 1159 template <typename T> 1160 struct squared_difference 1161 : base<T, Eigen::internal::scalar_squared_difference_op<T>> {}; 1162 1163 template <typename T> 1164 struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {}; 1165 1166 template <typename T> 1167 struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {}; 1168 1169 template <typename T> 1170 struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {}; 1171 1172 template <typename T> 1173 struct less : base<T, Eigen::internal::less<T>, bool> {}; 1174 1175 template <typename T> 1176 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {}; 1177 1178 template <typename T> 1179 struct greater : base<T, Eigen::internal::greater<T>, bool> {}; 1180 1181 template <typename T> 1182 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {}; 1183 1184 template <typename T> 1185 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {}; 1186 1187 template <typename T> 1188 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {}; 1189 1190 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {}; 1191 1192 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {}; 1193 1194 template <typename T> 1195 struct bitwise_and_op { 1196 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1197 const T& y) const { 1198 return x & y; 1199 } 1200 }; 1201 1202 template <typename T> 1203 struct bitwise_or_op { 1204 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1205 const T& y) const { 1206 return x | y; 1207 } 1208 }; 1209 1210 template <typename T> 1211 struct bitwise_and : base<T, bitwise_and_op<T>> {}; 1212 1213 template <typename T> 1214 struct bitwise_or : base<T, bitwise_or_op<T>> {}; 1215 1216 template <typename T> 1217 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {}; 1218 1219 template <typename T> 1220 struct left_shift_op { 1221 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1222 const T& y) const { 1223 // Avoids UB: don't shift by larger than the bitwidth of T, and 1224 // performs left shifts as unsigned shifts. 1225 T y_clamped = y; 1226 if (y_clamped < 0) { 1227 y_clamped = 0; 1228 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 1229 y_clamped = sizeof(T) * CHAR_BIT - 1; 1230 } 1231 using U = typename std::make_unsigned<T>::type; 1232 return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped)); 1233 } 1234 }; 1235 1236 template <typename T> 1237 struct right_shift_op { 1238 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1239 const T& y) const { 1240 // Avoids UB: don't shift by larger than the bitwidth of T. 1241 T y_clamped = y; 1242 if (y_clamped < 0) { 1243 y_clamped = 0; 1244 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 1245 y_clamped = sizeof(T) * CHAR_BIT - 1; 1246 } 1247 // Technically right shifts of signed integers are not necessarily 1248 // arithmetic shifts according to the C++ standard. However in practice most 1249 // implementations are arithmetic shifts. If this proves to be a problem in 1250 // practice, we may need to use an alternative implementation. 1251 return x >> y_clamped; 1252 } 1253 }; 1254 1255 template <typename T> 1256 struct left_shift : base<T, left_shift_op<T>> {}; 1257 1258 template <typename T> 1259 struct right_shift : base<T, right_shift_op<T>> {}; 1260 1261 template <typename T> 1262 struct make_complex_func { 1263 typedef std::complex<T> result_type; 1264 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, 1265 T imag) const { 1266 return std::complex<T>(real, imag); 1267 } 1268 }; 1269 1270 template <typename T> 1271 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {}; 1272 1273 template <typename T> 1274 struct get_real 1275 : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {}; 1276 1277 template <typename T> 1278 struct get_imag 1279 : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {}; 1280 1281 template <typename T> 1282 struct get_angle 1283 : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {}; 1284 1285 template <typename T> 1286 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {}; 1287 1288 //////////////////////////////////////////////////////////////////////////////// 1289 // Functors takes 1 or 2 tensors, computes the base functor on 1290 // coefficient of the input tensors and puts the results in the output 1291 // tensor. 1292 //////////////////////////////////////////////////////////////////////////////// 1293 template <typename Device, typename Functor> 1294 struct UnaryFunctor { 1295 // Computes on device "d": out[i] = Functor(in[i]) 1296 void operator()(const Device& d, typename Functor::tout_type out, 1297 typename Functor::tin_type in); 1298 }; 1299 1300 template <typename Device, typename Functor, typename Targ> 1301 struct UnaryFunctorWithArg { 1302 // Computes on device "d": out[i] = Functor(in[i]) 1303 void operator()(const Device& d, typename Functor::tout_type out, 1304 typename Functor::tin_type in, Targ val); 1305 }; 1306 1307 template <typename Device, typename Functor, int NDIMS, 1308 bool has_errors = Functor::has_errors> 1309 struct BinaryFunctor { 1310 // Computes on device "d": out[i] = Functor(in0[i], in1[i]) 1311 void operator()(const Device& d, typename Functor::tout_type out, 1312 typename Functor::tin_type in0, 1313 typename Functor::tin_type in1, bool* error); 1314 1315 // Computes on device "d": out[i] = Functor(scalar[0], in[i]) 1316 void Left(const Device& d, typename Functor::tout_type out, 1317 typename Functor::tscalar_type scalar, 1318 typename Functor::tin_type in, bool* error); 1319 1320 // Computes on device "d": out[i] = Functor(in[i], scalar[0]) 1321 void Right(const Device& d, typename Functor::tout_type out, 1322 typename Functor::tin_type in, 1323 typename Functor::tscalar_type scalar, bool* error); 1324 1325 // Computes on device "d": 1326 // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1)) 1327 // 1328 // TODO(zhifengc): makes BCast a template member function on NDIMS 1329 // instead making BinaryFunctor templates on NDIMS. 1330 void BCast(const Device& d, 1331 typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, 1332 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, 1333 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, 1334 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, 1335 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, 1336 bool* error); 1337 }; 1338 1339 template <typename Device, typename T> 1340 struct ApproximateEqual { 1341 void operator()(const Device& d, typename TTypes<T>::ConstFlat x, 1342 typename TTypes<T>::ConstFlat y, T tolerance, 1343 typename TTypes<bool>::Flat z); 1344 }; 1345 1346 template <int NDIMS> 1347 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) { 1348 for (size_t i = 0; i < a.size(); ++i) { 1349 if (a[i] != 1) return false; 1350 } 1351 return true; 1352 } 1353 1354 template <typename Device, typename T> 1355 struct SelectFunctor { 1356 void operator()(const Device& d, typename TTypes<T>::Flat out, 1357 typename TTypes<bool>::ConstFlat cond_flat, 1358 typename TTypes<T>::ConstFlat then_flat, 1359 typename TTypes<T>::ConstFlat else_flat); 1360 }; 1361 1362 template <typename Device, typename T> 1363 struct SelectScalarFunctor { 1364 void operator()(const Device& d, typename TTypes<T>::Flat out, 1365 typename TTypes<bool>::ConstScalar cond, 1366 typename TTypes<T>::ConstFlat then_flat, 1367 typename TTypes<T>::ConstFlat else_flat); 1368 }; 1369 1370 template <typename Device, typename T> 1371 struct BatchSelectFunctor { 1372 void operator()(const Device& d, 1373 typename TTypes<T>::Matrix output_flat_outer_dims, 1374 TTypes<bool>::ConstVec cond_vec, 1375 typename TTypes<T>::ConstMatrix then_flat_outer_dims, 1376 typename TTypes<T>::ConstMatrix else_flat_outer_dims); 1377 }; 1378 1379 template <typename Device, typename T, int NDIMS> 1380 struct BCastSelectFunctor { 1381 void operator()(const Device& d, 1382 typename TTypes<T, NDIMS>::Tensor output_tensor, 1383 typename TTypes<bool, NDIMS>::ConstTensor cond_tensor, 1384 typename TTypes<T, NDIMS>::ConstTensor then_tensor, 1385 typename TTypes<T, NDIMS>::ConstTensor else_tensor, 1386 typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast, 1387 typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast, 1388 typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast); 1389 }; 1390 1391 } // end namespace functor 1392 } // end namespace tensorflow 1393 1394 #endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 1395