1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 18 19 #define _USE_MATH_DEFINES 20 #include <cmath> 21 #include <functional> 22 #include <type_traits> 23 24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 25 #include "tensorflow/core/framework/bounds_check.h" 26 #include "tensorflow/core/framework/numeric_types.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 29 namespace Eigen { 30 namespace internal { 31 32 #if GOOGLE_CUDA 33 template <> 34 struct scalar_arg_op<std::complex<float>> { 35 EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op) 36 typedef typename Eigen::NumTraits<std::complex<float>>::Real result_type; 37 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( 38 const std::complex<float>& a) const { 39 return ::atan2f(a.imag(), a.real()); 40 } 41 }; 42 43 template <> 44 struct scalar_arg_op<std::complex<double>> { 45 EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op) 46 typedef typename Eigen::NumTraits<std::complex<double>>::Real result_type; 47 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()( 48 const std::complex<double>& a) const { 49 return ::atan2(a.imag(), a.real()); 50 } 51 }; 52 #endif 53 54 #if EIGEN_HAS_CXX11_MATH == 0 55 template <typename T> 56 struct scalar_asinh_op { 57 EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op) 58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 59 return std::asinh(a); 60 } 61 }; 62 template <typename T> 63 struct functor_traits<scalar_asinh_op<T>> { 64 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 65 }; 66 67 template <typename T> 68 struct scalar_acosh_op { 69 EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op) 70 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 71 return std::acosh(a); 72 } 73 }; 74 template <typename T> 75 struct functor_traits<scalar_acosh_op<T>> { 76 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 77 }; 78 79 template <typename T> 80 struct scalar_atanh_op { 81 EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op) 82 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 83 return std::atanh(a); 84 } 85 }; 86 template <typename T> 87 struct functor_traits<scalar_atanh_op<T>> { 88 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 89 }; 90 #endif 91 92 template <typename Scalar, typename Exponent> 93 struct safe_scalar_binary_pow_op { 94 static_assert(std::is_integral<Scalar>::value, "Integer type expected"); 95 static_assert(std::is_integral<Exponent>::value && 96 std::is_signed<Exponent>::value, 97 "Signed integer type expected"); 98 99 bool* const error; 100 101 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) 102 : error(error) {} 103 104 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 105 const Exponent& b) const { 106 const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); 107 if (TF_PREDICT_TRUE(safe_b >= 0)) { 108 return numext::pow(a, safe_b); 109 } else { 110 *error = true; 111 return 0; 112 } 113 } 114 }; 115 116 template <typename Scalar, typename Exponent> 117 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { 118 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 119 }; 120 121 template <typename T, typename DivOrMod> 122 struct safe_div_or_mod_op { 123 static_assert(std::is_integral<T>::value, "Integer type expected"); 124 125 bool* const error; 126 127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) 128 : error(error) {} 129 130 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, 131 const T& b) const { 132 const T safe_b = tensorflow::internal::SubtleMustCopy(b); 133 if (TF_PREDICT_TRUE(safe_b != 0)) { 134 return DivOrMod()(a, safe_b); 135 } else { 136 *error = true; 137 return 0; 138 } 139 } 140 }; 141 142 template <typename T, typename DivOrMod> 143 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> { 144 enum { 145 Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost, 146 PacketAccess = false, 147 }; 148 }; 149 150 template <typename T, typename Binary> 151 struct no_nan_op { 152 EIGEN_EMPTY_STRUCT_CTOR(no_nan_op) 153 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, 154 const T& b) const { 155 if (b != T(0)) { 156 return Binary()(a, b); 157 } else { 158 return T(0); 159 } 160 } 161 template <typename Packet> 162 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 163 const Packet& b) const { 164 const Packet mask = pcmp_eq(b, pzero(b)); 165 const Packet quotient = Binary().packetOp(a, b); 166 return pandnot(quotient, mask); 167 } 168 }; 169 170 template <typename T> 171 struct div_no_nan_op : public no_nan_op<T, scalar_quotient_op<T>> { 172 EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op) 173 }; 174 175 template <typename T> 176 struct functor_traits<div_no_nan_op<T>> { 177 enum { 178 Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost, 179 PacketAccess = true, 180 }; 181 }; 182 183 template <typename T> 184 struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> { 185 EIGEN_EMPTY_STRUCT_CTOR(mul_no_nan_op) 186 }; 187 188 template <typename T> 189 struct functor_traits<mul_no_nan_op<T>> { 190 enum { 191 Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost, 192 PacketAccess = true, 193 }; 194 }; 195 196 // scalar_left and scalar_right are template helpers to partially 197 // apply a binary function. 198 // 199 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a 200 // unary functor g_x(y) = f(x, y), where x is provided via the 201 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) = 202 // f(x, y). 203 204 template <typename Tout, typename Tin, typename Binary, 205 bool is_scalar_in_host_memory = false> 206 struct scalar_left : private Binary { 207 using result_type = Tout; 208 using TinPacket = typename Eigen::internal::packet_traits<Tin>::type; 209 210 const Tin* left; 211 TinPacket left_packet; // initialized iff is_scalar_in_host_memory == true 212 213 EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default; 214 215 template <typename... Args> 216 EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args) 217 : Binary(args...), left(c) { 218 if (is_scalar_in_host_memory) { 219 left_packet = Eigen::internal::pset1<TinPacket>(*left); 220 } 221 } 222 223 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { 224 return Binary::operator()(*left, right); 225 } 226 227 template <typename Packet, 228 typename std::enable_if<!is_scalar_in_host_memory || 229 !std::is_same<TinPacket, Packet>::value, 230 int>::type = 0> 231 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 232 const Packet left_packet = Eigen::internal::pset1<Packet>(*left); 233 return Binary::packetOp(left_packet, right_packet); 234 } 235 236 template <typename Packet, 237 typename std::enable_if<is_scalar_in_host_memory && 238 std::is_same<TinPacket, Packet>::value, 239 int>::type = 0> 240 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 241 return Binary::packetOp(left_packet, right_packet); 242 } 243 }; 244 245 template <typename Tout, typename Tin, typename Binary, 246 bool is_scalar_in_host_memory> 247 struct functor_traits< 248 scalar_left<Tout, Tin, Binary, is_scalar_in_host_memory>> { 249 enum { 250 Cost = functor_traits<Binary>::Cost, 251 PacketAccess = functor_traits<Binary>::PacketAccess, 252 }; 253 }; 254 255 template <typename Tout, typename Tin, typename Binary, 256 bool is_scalar_in_host_memory = false> 257 struct scalar_right : private Binary { 258 using result_type = Tout; 259 using TinPacket = typename Eigen::internal::packet_traits<Tin>::type; 260 261 const Tin* right; 262 TinPacket right_packet; // initialized iff is_scalar_in_host_memory == true 263 264 EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default; 265 266 template <typename... Args> 267 EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args) 268 : Binary(args...), right(c) { 269 if (is_scalar_in_host_memory) { 270 right_packet = Eigen::internal::pset1<TinPacket>(*right); 271 } 272 } 273 274 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { 275 return Binary::operator()(left, *right); 276 } 277 278 template <typename Packet, 279 typename std::enable_if<!is_scalar_in_host_memory || 280 !std::is_same<TinPacket, Packet>::value, 281 int>::type = 0> 282 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 283 const Packet right_packet = Eigen::internal::pset1<Packet>(*right); 284 return Binary::packetOp(left_packet, right_packet); 285 } 286 287 template <typename Packet, 288 typename std::enable_if<is_scalar_in_host_memory && 289 std::is_same<TinPacket, Packet>::value, 290 int>::type = 0> 291 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 292 return Binary::packetOp(left_packet, right_packet); 293 } 294 }; 295 296 template <typename Tout, typename Tin, typename Binary, 297 bool is_scalar_in_host_memory> 298 struct functor_traits< 299 scalar_right<Tout, Tin, Binary, is_scalar_in_host_memory>> { 300 enum { 301 Cost = functor_traits<Binary>::Cost, 302 PacketAccess = functor_traits<Binary>::PacketAccess, 303 }; 304 }; 305 306 // similar to std::equal_to, but with the DEVICE_FUNC qualifier 307 template <class T> 308 struct equal_to : std::binary_function<T, T, bool> { 309 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 310 const T& y) const { 311 return x == y; 312 } 313 }; 314 315 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier 316 template <class T> 317 struct not_equal_to : std::binary_function<T, T, bool> { 318 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 319 const T& y) const { 320 return x != y; 321 } 322 }; 323 324 // similar to std::greater, but with the DEVICE_FUNC qualifier 325 template <class T> 326 struct greater : std::binary_function<T, T, bool> { 327 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 328 const T& y) const { 329 return x > y; 330 } 331 }; 332 333 // similar to std::less, but with the DEVICE_FUNC qualifier 334 template <class T> 335 struct less : std::binary_function<T, T, bool> { 336 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 337 const T& y) const { 338 return x < y; 339 } 340 }; 341 342 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier 343 template <class T> 344 struct greater_equal : std::binary_function<T, T, bool> { 345 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 346 const T& y) const { 347 return x >= y; 348 } 349 }; 350 351 // similar to std::less_equal, but with the DEVICE_FUNC qualifier 352 template <class T> 353 struct less_equal : std::binary_function<T, T, bool> { 354 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 355 const T& y) const { 356 return x <= y; 357 } 358 }; 359 360 // Functor that enables squared difference functor. 361 template <typename Scalar> 362 struct scalar_squared_difference_op { 363 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 364 operator()(const Scalar& a, const Scalar& b) const { 365 const Scalar v = scalar_difference_op<Scalar>()(a, b); 366 return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v)); 367 } 368 template <typename Packet> 369 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 370 const Packet& b) const { 371 const Packet v = scalar_difference_op<Scalar>().packetOp(a, b); 372 return scalar_product_op<Scalar>().packetOp( 373 v, scalar_conjugate_op<Scalar>().packetOp(v)); 374 } 375 }; 376 377 template <typename Scalar> 378 struct functor_traits<scalar_squared_difference_op<Scalar>> { 379 enum { 380 Cost = functor_traits<scalar_difference_op<Scalar>>::Cost + 381 functor_traits<scalar_conjugate_op<Scalar>>::Cost + 382 functor_traits<scalar_product_op<Scalar>>::Cost, 383 PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess && 384 functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess && 385 functor_traits<scalar_product_op<Scalar>>::PacketAccess 386 }; 387 }; 388 389 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 390 template <typename T, typename Enable = void> 391 struct google_floor_div { 392 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 393 const T& y) const { 394 if ((x < T(0)) != (y < T(0))) { 395 // HIP does not have the device version of the abs routine defined 396 // for all datatypes that T can resolve to 397 #if defined(__HIP_DEVICE_COMPILE__) 398 T abs_x = (x < T(0)) ? -x : x; 399 T abs_y = (y < T(0)) ? -y : y; 400 #else 401 T abs_x = std::abs(x); 402 T abs_y = std::abs(y); 403 #endif 404 return -(abs_x + abs_y - 1) / abs_y; 405 } else { 406 return x / y; 407 } 408 } 409 template <typename Packet> 410 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 411 const Packet& y) const { 412 Packet zeros = pzero(x); 413 Packet x_mask = pcmp_lt(x, zeros); 414 Packet y_mask = pcmp_lt(y, zeros); 415 Packet x_div_y = pdiv(x, y); 416 Packet abs_x = pabs(x); 417 Packet abs_y = pabs(y); 418 Packet ones = pones(x); 419 Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y); 420 return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y); 421 } 422 }; 423 424 template <typename T> 425 struct google_floor_div< 426 T, typename std::enable_if<std::is_unsigned<T>::value>::type> { 427 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 428 const T& y) const { 429 return x / y; 430 } 431 template <typename Packet> 432 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 433 const Packet& y) const { 434 return pdiv(x, y); 435 } 436 }; 437 438 template <typename Scalar> 439 struct functor_traits<google_floor_div<Scalar>> { 440 enum { 441 Cost = 2 * Eigen::internal::scalar_div_cost< 442 Scalar, packet_traits<Scalar>::HasDiv>::value + 443 NumTraits<Scalar>::AddCost, 444 PacketAccess = packet_traits<Scalar>::HasDiv 445 }; 446 }; 447 448 template <typename T, typename Enable = void> 449 struct google_floor_div_real { 450 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 451 const T& y) const { 452 return Eigen::numext::floor(x / y); 453 } 454 template <typename Packet> 455 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 456 const Packet& y) const { 457 return pfloor(pdiv(x, y)); 458 } 459 }; 460 461 template <typename Scalar> 462 struct functor_traits<google_floor_div_real<Scalar>> { 463 enum { 464 Cost = 2 * Eigen::internal::scalar_div_cost< 465 Scalar, packet_traits<Scalar>::HasDiv>::value + 466 2 * NumTraits<Scalar>::AddCost, 467 PacketAccess = 468 packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor 469 }; 470 }; 471 472 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. 473 template <typename T> 474 struct google_floor_fmod { 475 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 476 const T& y) const { 477 // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); 478 T trunc_mod = scalar_fmod_op<T>()(x, y); 479 return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y 480 : trunc_mod; 481 } 482 }; 483 484 template <typename Scalar> 485 struct functor_traits<google_floor_fmod<Scalar>> { 486 enum { 487 Cost = functor_traits<Eigen::internal::scalar_fmod_op<Scalar>>::Cost + 488 NumTraits<Scalar>::AddCost, 489 PacketAccess = false 490 }; 491 }; 492 493 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. 494 template <typename T> 495 struct google_floor_mod { 496 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 497 const T& y) const { 498 // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); 499 T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(x, y); 500 return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y 501 : trunc_mod; 502 } 503 }; 504 505 template <typename Scalar> 506 struct functor_traits<google_floor_mod<Scalar>> { 507 enum { 508 Cost = functor_traits<Eigen::internal::scalar_mod2_op<Scalar>>::Cost + 509 NumTraits<Scalar>::AddCost, 510 PacketAccess = false 511 }; 512 }; 513 514 #if EIGEN_COMP_GNUC && __cplusplus > 199711L 515 #define DISABLE_FLOAT_EQUALITY_WARNING \ 516 _Pragma("GCC diagnostic push") \ 517 _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") 518 #define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") 519 #else 520 #define DISABLE_FLOAT_EQUALITY_WARNING 521 #define ENABLE_FLOAT_EQUALITY_WARNING 522 #endif 523 524 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger, 525 bool HasRint = packet_traits<Scalar>::HasRint> 526 struct scalar_round_half_to_even_op { 527 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 528 operator()(const Scalar& x) const { 529 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 530 NUMERIC_TYPE_MUST_BE_REAL) 531 532 const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5)); 533 const Scalar fraction = round_val - x; 534 if (TF_PREDICT_FALSE(fraction == Scalar(.5))) { 535 return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5)); 536 } else { 537 return round_val; 538 } 539 } 540 541 template <typename Packet> 542 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 543 Packet half = pset1<Packet>(Scalar(0.5)); 544 Packet round_val = pfloor(padd(x, half)); 545 Packet fraction = psub(round_val, x); 546 Packet half_mask = pcmp_eq(fraction, half); 547 bool any_halves = predux_any(half_mask); 548 if (TF_PREDICT_FALSE(any_halves)) { 549 Packet two = pset1<Packet>(Scalar(2)); 550 Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half))); 551 return pselect(half_mask, nearest_even, round_val); 552 } else { 553 return round_val; 554 } 555 } 556 }; 557 558 template <typename Scalar> 559 struct scalar_round_half_to_even_op<Scalar, true, false> { 560 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 561 operator()(const Scalar& x) const { 562 return x; 563 } 564 template <typename Packet> 565 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 566 return x; 567 } 568 }; 569 570 template <typename Scalar> 571 struct scalar_round_half_to_even_op<Scalar, false, true> { 572 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 573 operator()(const Scalar& x) const { 574 return Eigen::numext::rint(x); 575 } 576 template <typename Packet> 577 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 578 return print(x); 579 } 580 }; 581 582 template <typename Scalar> 583 struct functor_traits<scalar_round_half_to_even_op<Scalar>> { 584 enum { 585 Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0 586 : 4 * NumTraits<Scalar>::AddCost, 587 PacketAccess = packet_traits<Scalar>::HasFloor && 588 packet_traits<Scalar>::HasAdd && 589 packet_traits<Scalar>::HasMul, 590 }; 591 }; 592 593 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger> 594 struct scalar_round_up_op { 595 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 596 operator()(const Scalar& x) const { 597 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 598 NUMERIC_TYPE_MUST_BE_REAL) 599 return Eigen::numext::floor(x + Scalar(0.5)); 600 } 601 602 template <typename Packet> 603 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 604 return pfloor(padd(x, pset1<Packet>(0.5))); 605 } 606 }; 607 608 template <typename Scalar> 609 struct scalar_round_up_op<Scalar, true> { 610 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 611 operator()(const Scalar& x) const { 612 return x; 613 } 614 615 template <typename Packet> 616 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 617 return x; 618 } 619 }; 620 621 template <typename Scalar, bool IsInteger> 622 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> { 623 enum { 624 Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost, 625 PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor 626 }; 627 }; 628 629 #undef ENABLE_FLOAT_EQUALITY_WARNING 630 #undef DISABLE_FLOAT_EQUALITY_WARNING 631 632 template <typename Scalar> 633 struct bitwise_xor_op { 634 EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op) 635 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 636 operator()(const Scalar& x, const Scalar& y) const { 637 return x ^ y; 638 } 639 template <typename Packet> 640 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 641 const Packet& b) const { 642 return Eigen::internal::pxor(a, b); 643 } 644 }; 645 646 template <typename Scalar> 647 struct functor_traits<bitwise_xor_op<Scalar>> { 648 enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; 649 }; 650 651 template <typename Scalar> 652 struct xlogy_op { 653 EIGEN_EMPTY_STRUCT_CTOR(xlogy_op) 654 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 655 operator()(const Scalar& x, const Scalar& y) const { 656 if (x == Scalar(0.)) { 657 return Scalar(0.); 658 } 659 return x * numext::log(y); 660 } 661 template <typename Packet> 662 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 663 const Packet& y) const { 664 Packet zeros = pzero(x); 665 Packet mask = pcmp_eq(x, zeros); 666 scalar_log_op<Scalar> log_op; 667 Packet log_y = log_op.packetOp(y); 668 Packet x_log_y = pmul(x, log_y); 669 return pselect(mask, x, x_log_y); 670 } 671 }; 672 673 template <typename Scalar> 674 struct functor_traits<xlogy_op<Scalar>> { 675 enum { 676 Cost = functor_traits<scalar_log_op<Scalar>>::Cost + 677 Eigen::NumTraits<Scalar>::MulCost, 678 PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess 679 }; 680 }; 681 682 template <typename Scalar> 683 struct xlog1py_op { 684 EIGEN_EMPTY_STRUCT_CTOR(xlog1py_op) 685 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 686 operator()(const Scalar& x, const Scalar& y) const { 687 if (x == Scalar(0.)) { 688 return Scalar(0.); 689 } 690 return x * numext::log1p(y); 691 } 692 template <typename Packet> 693 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 694 const Packet& y) const { 695 Packet zeros = pzero(x); 696 Packet mask = pcmp_eq(x, zeros); 697 scalar_log1p_op<Scalar> log1p_op; 698 Packet log1p_y = log1p_op.packetOp(y); 699 Packet x_log1p_y = pmul(x, log1p_y); 700 return pselect(mask, x, x_log1p_y); 701 } 702 }; 703 704 template <typename Scalar> 705 struct functor_traits<xlog1py_op<Scalar>> { 706 enum { 707 Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost + 708 Eigen::NumTraits<Scalar>::MulCost, 709 #if TENSORFLOW_USE_ROCM 710 PacketAccess = false, 711 #else 712 PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess 713 #endif 714 }; 715 }; 716 717 template <typename Scalar> 718 struct xdivy_op { 719 EIGEN_EMPTY_STRUCT_CTOR(xdivy_op) 720 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 721 operator()(const Scalar& x, const Scalar& y) const { 722 if (x == Scalar(0.)) { 723 return Scalar(0.); 724 } 725 return x / y; 726 } 727 template <typename Packet> 728 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, 729 const Packet& y) const { 730 Packet zeros = pzero(x); 731 Packet mask = pcmp_eq(x, zeros); 732 Packet x_div_y = pdiv(x, y); 733 return pselect(mask, x, x_div_y); 734 } 735 }; 736 737 template <typename Scalar> 738 struct functor_traits<xdivy_op<Scalar>> { 739 enum { 740 Cost = 741 Eigen::NumTraits<Scalar>::AddCost + 742 Eigen::internal::scalar_div_cost<Scalar, 743 packet_traits<Scalar>::HasDiv>::value, 744 PacketAccess = packet_traits<Scalar>::HasDiv 745 }; 746 }; 747 748 template <typename T> 749 struct scalar_erfinv_op { 750 EIGEN_EMPTY_STRUCT_CTOR(scalar_erfinv_op) 751 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { 752 constexpr T half = T(0.5); 753 T y = numext::ndtri(half * x + half); 754 constexpr T half_sqrt = T(M_SQRT1_2); 755 return y * half_sqrt; 756 } 757 template <typename Packet> 758 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { 759 Packet half = pset1<Packet>(T(0.5)); 760 Packet y = pndtri<Packet>(pmadd(half, x, half)); 761 Packet half_sqrt = pset1<Packet>(T(M_SQRT1_2)); 762 return pmul(y, half_sqrt); 763 } 764 }; 765 766 template <typename T> 767 struct functor_traits<scalar_erfinv_op<T>> { 768 enum { 769 Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost, 770 PacketAccess = packet_traits<T>::HasNdtri, 771 }; 772 }; 773 774 } // end namespace internal 775 } // end namespace Eigen 776 777 namespace tensorflow { 778 namespace functor { 779 780 //////////////////////////////////////////////////////////////////////////////// 781 // Helpers 782 //////////////////////////////////////////////////////////////////////////////// 783 784 // Base template for functors whose input scalar type is T and 785 // output scalar type is R. 786 template <typename T, typename F, typename R = T> 787 struct base { 788 // func defines operator() and its vectorized version packetOp(). 789 typedef F func; 790 791 // If true, the functor's corresponding binary op will instantiate 792 // specialized kernels to perform an optimized broadcast 793 // operation. Each functor for which this is enabled increases the 794 // code size, so by default this is disabled for binary functors and 795 // is enabled on a per-op basis as needed. 796 static constexpr bool use_bcast_optimization = false; 797 798 // operator() has the signature: 799 // out_type operator()(in_type in0, in_type in1 ...) 800 typedef R out_type; 801 typedef T in_type; 802 803 // TensorFlow provides tensor-ized version of "func". Roughly 804 // speaking, the tensorflow operation has the signature: 805 // tout_type op(tin_type in0) 806 // tout_type op(tin_type in0, tin_type in1) 807 // tout_type op(tin_type in0, in_type scalar) 808 typedef typename TTypes<out_type>::Flat tout_type; 809 typedef typename TTypes<in_type>::ConstFlat tin_type; 810 typedef typename TTypes<in_type>::ConstScalar tscalar_type; 811 812 // Whether the functor can error out. Currently applies only to integer 813 // div and mod. 814 static constexpr bool has_errors = false; 815 }; 816 817 // For now, we only apply certain speed optimization for 818 // float/double's broadcast binary op. 819 template <typename T> 820 struct use_bcast_optimization { 821 static constexpr bool value = false; 822 }; 823 824 template <> 825 struct use_bcast_optimization<float> { 826 static constexpr bool value = true; 827 }; 828 829 template <> 830 struct use_bcast_optimization<double> { 831 static constexpr bool value = true; 832 }; 833 834 //////////////////////////////////////////////////////////////////////////////// 835 // Unary functors 836 //////////////////////////////////////////////////////////////////////////////// 837 838 // abs(x) = |x| 839 // neg(x) = - x 840 // inverse(x) = 1 / x 841 // square(x) = x^2 842 // sqrt(x) = x^(1/2) 843 // rsqrt(x) = x^(-1/2) 844 // exp(x) = e^x 845 // expm1(x) = e^x - 1 846 // log(x) = natural logarithm of x 847 // log1p(x) = natural logarithm of 1 + x 848 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) 849 // sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic 850 // 851 // NOTE: We may eventually implement common functions used in NN 852 // here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. 853 // For reference, see speech/lstm/eigen_functors.h. 854 855 template <typename T> 856 struct abs : base<T, Eigen::internal::scalar_abs_op<T>, 857 typename Eigen::internal::scalar_abs_op<T>::result_type> {}; 858 859 template <typename T> 860 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {}; 861 862 template <typename T> 863 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {}; 864 865 template <typename T> 866 struct square : base<T, Eigen::internal::scalar_square_op<T>> {}; 867 868 template <typename T> 869 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {}; 870 871 template <typename T> 872 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {}; 873 874 template <typename T> 875 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {}; 876 877 template <typename T> 878 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {}; 879 880 template <typename T> 881 struct log : base<T, Eigen::internal::scalar_log_op<T>> {}; 882 883 template <typename T> 884 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {}; 885 886 template <typename T> 887 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {}; 888 889 template <typename T> 890 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {}; 891 892 template <typename T> 893 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {}; 894 895 template <typename T> 896 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {}; 897 898 template <typename T> 899 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {}; 900 901 template <typename T> 902 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {}; 903 904 template <typename T> 905 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {}; 906 907 template <typename T> 908 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {}; 909 910 template <typename T> 911 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {}; 912 913 template <typename T> 914 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {}; 915 916 template <typename T> 917 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {}; 918 919 template <typename T> 920 struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {}; 921 922 template <typename T> 923 struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {}; 924 925 template <typename T> 926 struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {}; 927 928 template <typename T> 929 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {}; 930 931 template <typename T> 932 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {}; 933 934 template <typename T> 935 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {}; 936 937 template <typename T> 938 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {}; 939 940 template <typename T> 941 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {}; 942 943 template <typename T> 944 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {}; 945 946 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> { 947 }; 948 949 // Flip all bits. Named invert to be consistent with numpy. 950 template <typename T> 951 struct invert_op { 952 EIGEN_EMPTY_STRUCT_CTOR(invert_op) 953 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { 954 return ~a; 955 } 956 }; 957 958 template <typename T> 959 struct invert : base<T, invert_op<T>> {}; 960 961 // NOTE: std::isinf, std::isnan, std::isfinite are plain function. 962 // Therefore we need to wrap them in functors to be used with Eigen's 963 // type system. 964 template <typename T> 965 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {}; 966 967 template <typename T> 968 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {}; 969 970 template <typename T> 971 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {}; 972 973 template <typename T> 974 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {}; 975 976 template <typename T> 977 struct round : base<T, Eigen::internal::scalar_round_half_to_even_op<T>> {}; 978 979 template <typename T> 980 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {}; 981 982 // Note: rint rounds half values to even, just like round_half_to_even_op. 983 template <typename T> 984 struct rint : base<T, Eigen::internal::scalar_rint_op<T>> {}; 985 986 //////////////////////////////////////////////////////////////////////////////// 987 // Binary functors 988 //////////////////////////////////////////////////////////////////////////////// 989 990 // Binary functors: 991 // 992 // add(x, y) = x + y 993 // sub(x, y) = x - y 994 // mul(x, y) = x * y 995 // div(x, y) = x / y 996 // mod(x, y) = x % y (int32 and int64 only) 997 // fmod(x, y) = fmod(x, y) (float and double only) 998 // pow(x, y) = x ^ y 999 // maximum(x, y) = x > y ? x : y 1000 // minimum(x, y) = x < y ? x : y 1001 // squared_difference(x, y) = conj(x - y) * (x - y) 1002 1003 template <typename T> 1004 struct add : base<T, Eigen::internal::scalar_sum_op<T>> { 1005 static constexpr bool use_bcast_optimization = true; 1006 }; 1007 1008 template <typename T> 1009 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> { 1010 static constexpr bool use_bcast_optimization = true; 1011 }; 1012 1013 template <typename T> 1014 struct mul : base<T, Eigen::internal::scalar_product_op<T>> { 1015 static constexpr bool use_bcast_optimization = true; 1016 }; 1017 1018 template <typename T> 1019 struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {}; 1020 1021 template <typename T> 1022 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {}; 1023 1024 template <typename T> 1025 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op< 1026 T, Eigen::internal::scalar_quotient_op<T>>> { 1027 static constexpr bool has_errors = true; 1028 }; 1029 1030 template <typename T> 1031 struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {}; 1032 1033 template <typename T> 1034 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {}; 1035 1036 template <typename T> 1037 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {}; 1038 1039 template <typename T> 1040 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op< 1041 T, Eigen::internal::scalar_mod2_op<T>>> { 1042 static constexpr bool has_errors = true; 1043 }; 1044 1045 template <typename T> 1046 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {}; 1047 1048 template <typename T> 1049 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op< 1050 T, Eigen::internal::google_floor_mod<T>>> { 1051 static constexpr bool has_errors = true; 1052 }; 1053 1054 template <typename T> 1055 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {}; 1056 1057 template <typename T> 1058 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op< 1059 T, Eigen::internal::google_floor_div<T>>> { 1060 static constexpr bool has_errors = true; 1061 }; 1062 1063 template <typename T> 1064 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {}; 1065 1066 template <typename T> 1067 struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {}; 1068 1069 template <typename T> 1070 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { 1071 static constexpr bool has_errors = true; 1072 }; 1073 1074 // Version of safe_pow for integers which returns 0 if RHS is negative and LHS 1075 // is not 1 or -1. For use on GPUs, where we cannot raise an error. 1076 template <typename T> 1077 struct safe_pow_ignore_error_op { 1078 static_assert(std::is_integral<T>::value, "Integer type expected"); 1079 EIGEN_EMPTY_STRUCT_CTOR(safe_pow_ignore_error_op) 1080 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1081 const T& y) const { 1082 if (TF_PREDICT_FALSE(y < 0)) { 1083 if (x == T(-1)) { 1084 T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2)); 1085 return trunc_mod == T(-1) ? T(-1) : T(1); 1086 } 1087 return x == T(1) ? T(1) : T(0); 1088 } 1089 return Eigen::internal::scalar_pow_op<T, T>{}(x, y); 1090 } 1091 }; 1092 1093 template <typename T> 1094 struct safe_pow_ignore_error : base<T, safe_pow_ignore_error_op<T>> {}; 1095 1096 template <typename T> 1097 struct maximum 1098 : base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {}; 1099 1100 template <typename T> 1101 struct minimum 1102 : base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {}; 1103 1104 template <typename T> 1105 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; 1106 1107 template <typename T> 1108 struct random_gamma_grad 1109 : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {}; 1110 1111 template <typename T> 1112 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; 1113 1114 template <typename T> 1115 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; 1116 1117 template <typename T> 1118 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; 1119 1120 template <typename Scalar> 1121 struct scalar_atan2_op { 1122 EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op) 1123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 1124 operator()(const Scalar& y, const Scalar& x) const { 1125 #if TENSORFLOW_USE_ROCM 1126 return ::atan2(y, x); 1127 #else 1128 return std::atan2(y, x); 1129 #endif 1130 } 1131 }; 1132 1133 template <typename T> 1134 struct atan2 : base<T, scalar_atan2_op<T>> {}; 1135 1136 template <typename T> 1137 struct squared_difference 1138 : base<T, Eigen::internal::scalar_squared_difference_op<T>> {}; 1139 1140 template <typename T> 1141 struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {}; 1142 1143 template <typename T> 1144 struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {}; 1145 1146 template <typename T> 1147 struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {}; 1148 1149 template <typename T> 1150 struct less : base<T, Eigen::internal::less<T>, bool> {}; 1151 1152 template <typename T> 1153 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {}; 1154 1155 template <typename T> 1156 struct greater : base<T, Eigen::internal::greater<T>, bool> {}; 1157 1158 template <typename T> 1159 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {}; 1160 1161 template <typename T> 1162 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {}; 1163 1164 template <typename T> 1165 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {}; 1166 1167 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {}; 1168 1169 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {}; 1170 1171 template <typename T> 1172 struct bitwise_and_op { 1173 EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op) 1174 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1175 const T& y) const { 1176 return x & y; 1177 } 1178 }; 1179 1180 template <typename T> 1181 struct bitwise_or_op { 1182 EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op) 1183 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1184 const T& y) const { 1185 return x | y; 1186 } 1187 }; 1188 1189 template <typename T> 1190 struct bitwise_and : base<T, bitwise_and_op<T>> {}; 1191 1192 template <typename T> 1193 struct bitwise_or : base<T, bitwise_or_op<T>> {}; 1194 1195 template <typename T> 1196 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {}; 1197 1198 template <typename T> 1199 struct left_shift_op { 1200 EIGEN_EMPTY_STRUCT_CTOR(left_shift_op) 1201 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1202 const T& y) const { 1203 // Avoids UB: don't shift by larger than the bitwidth of T, and 1204 // performs left shifts as unsigned shifts. 1205 T y_clamped = y; 1206 if (y_clamped < 0) { 1207 y_clamped = 0; 1208 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 1209 y_clamped = sizeof(T) * CHAR_BIT - 1; 1210 } 1211 using U = typename std::make_unsigned<T>::type; 1212 return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped)); 1213 } 1214 }; 1215 1216 template <typename T> 1217 struct right_shift_op { 1218 EIGEN_EMPTY_STRUCT_CTOR(right_shift_op) 1219 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, 1220 const T& y) const { 1221 // Avoids UB: don't shift by larger than the bitwidth of T. 1222 T y_clamped = y; 1223 if (y_clamped < 0) { 1224 y_clamped = 0; 1225 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 1226 y_clamped = sizeof(T) * CHAR_BIT - 1; 1227 } 1228 // Technically right shifts of signed integers are not necessarily 1229 // arithmetic shifts according to the C++ standard. However in practice most 1230 // implementations are arithmetic shifts. If this proves to be a problem in 1231 // practice, we may need to use an alternative implementation. 1232 return x >> y_clamped; 1233 } 1234 }; 1235 1236 template <typename T> 1237 struct left_shift : base<T, left_shift_op<T>> {}; 1238 1239 template <typename T> 1240 struct right_shift : base<T, right_shift_op<T>> {}; 1241 1242 template <typename T> 1243 struct make_complex_func { 1244 typedef std::complex<T> result_type; 1245 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, 1246 T imag) const { 1247 return std::complex<T>(real, imag); 1248 } 1249 }; 1250 1251 template <typename T> 1252 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {}; 1253 1254 template <typename T> 1255 struct get_real 1256 : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {}; 1257 1258 template <typename T> 1259 struct get_imag 1260 : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {}; 1261 1262 template <typename T> 1263 struct get_angle 1264 : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {}; 1265 1266 template <typename T> 1267 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {}; 1268 1269 //////////////////////////////////////////////////////////////////////////////// 1270 // Functors takes 1 or 2 tensors, computes the base functor on 1271 // coefficient of the input tensors and puts the results in the output 1272 // tensor. 1273 //////////////////////////////////////////////////////////////////////////////// 1274 template <typename Device, typename Functor> 1275 struct UnaryFunctor { 1276 // Computes on device "d": out[i] = Functor(in[i]) 1277 void operator()(const Device& d, typename Functor::tout_type out, 1278 typename Functor::tin_type in); 1279 }; 1280 1281 template <typename Device, typename Functor, int NDIMS, 1282 bool has_errors = Functor::has_errors> 1283 struct BinaryFunctor { 1284 // Computes on device "d": out[i] = Functor(in0[i], in1[i]) 1285 void operator()(const Device& d, typename Functor::tout_type out, 1286 typename Functor::tin_type in0, 1287 typename Functor::tin_type in1, bool* error); 1288 1289 // Computes on device "d": out[i] = Functor(scalar[0], in[i]) 1290 void Left(const Device& d, typename Functor::tout_type out, 1291 typename Functor::tscalar_type scalar, 1292 typename Functor::tin_type in, bool* error); 1293 1294 // Computes on device "d": out[i] = Functor(in[i], scalar[0]) 1295 void Right(const Device& d, typename Functor::tout_type out, 1296 typename Functor::tin_type in, 1297 typename Functor::tscalar_type scalar, bool* error); 1298 1299 // Computes on device "d": 1300 // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1)) 1301 // 1302 // TODO(zhifengc): makes BCast a template member function on NDIMS 1303 // instead making BinaryFunctor templates on NDIMS. 1304 void BCast(const Device& d, 1305 typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, 1306 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, 1307 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, 1308 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, 1309 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, 1310 bool* error); 1311 }; 1312 1313 template <typename Device, typename T> 1314 struct ApproximateEqual { 1315 void operator()(const Device& d, typename TTypes<T>::ConstFlat x, 1316 typename TTypes<T>::ConstFlat y, T tolerance, 1317 typename TTypes<bool>::Flat z); 1318 }; 1319 1320 template <int NDIMS> 1321 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) { 1322 for (size_t i = 0; i < a.size(); ++i) { 1323 if (a[i] != 1) return false; 1324 } 1325 return true; 1326 } 1327 1328 template <typename Device, typename T> 1329 struct SelectFunctor { 1330 void operator()(const Device& d, typename TTypes<T>::Flat out, 1331 typename TTypes<bool>::ConstFlat cond_flat, 1332 typename TTypes<T>::ConstFlat then_flat, 1333 typename TTypes<T>::ConstFlat else_flat); 1334 }; 1335 1336 template <typename Device, typename T> 1337 struct SelectScalarFunctor { 1338 void operator()(const Device& d, typename TTypes<T>::Flat out, 1339 typename TTypes<bool>::ConstScalar cond, 1340 typename TTypes<T>::ConstFlat then_flat, 1341 typename TTypes<T>::ConstFlat else_flat); 1342 }; 1343 1344 template <typename Device, typename T> 1345 struct BatchSelectFunctor { 1346 void operator()(const Device& d, 1347 typename TTypes<T>::Matrix output_flat_outer_dims, 1348 TTypes<bool>::ConstVec cond_vec, 1349 typename TTypes<T>::ConstMatrix then_flat_outer_dims, 1350 typename TTypes<T>::ConstMatrix else_flat_outer_dims); 1351 }; 1352 1353 template <typename Device, typename T, int NDIMS> 1354 struct BCastSelectFunctor { 1355 void operator()(const Device& d, 1356 typename TTypes<T, NDIMS>::Tensor output_tensor, 1357 typename TTypes<bool, NDIMS>::ConstTensor cond_tensor, 1358 typename TTypes<T, NDIMS>::ConstTensor then_tensor, 1359 typename TTypes<T, NDIMS>::ConstTensor else_tensor, 1360 typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast, 1361 typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast, 1362 typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast); 1363 }; 1364 1365 } // end namespace functor 1366 } // end namespace tensorflow 1367 1368 #endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 1369