1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 18 19 #include <cmath> 20 #include <functional> 21 #include <type_traits> 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 25 #include "tensorflow/core/framework/bounds_check.h" 26 #include "tensorflow/core/framework/numeric_types.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 29 namespace Eigen { 30 namespace internal { 31 32 #if GOOGLE_CUDA 33 template <> 34 struct scalar_arg_op<std::complex<float>> { 35 EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op) 36 typedef typename Eigen::NumTraits<std::complex<float>>::Real result_type; 37 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const float operator()( 38 const std::complex<float>& a) const { 39 return ::atan2f(a.imag(), a.real()); 40 } 41 }; 42 43 template <> 44 struct scalar_arg_op<std::complex<double>> { 45 EIGEN_EMPTY_STRUCT_CTOR(scalar_arg_op) 46 typedef typename Eigen::NumTraits<std::complex<double>>::Real result_type; 47 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const double operator()( 48 const std::complex<double>& a) const { 49 return ::atan2(a.imag(), a.real()); 50 } 51 }; 52 #endif 53 54 #if EIGEN_HAS_CXX11_MATH == 0 55 template <typename T> 56 struct scalar_asinh_op { 57 EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op) 58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 59 return std::asinh(a); 60 } 61 }; 62 template <typename T> 63 struct functor_traits<scalar_asinh_op<T>> { 64 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 65 }; 66 67 template <typename T> 68 struct scalar_acosh_op { 69 EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op) 70 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 71 return std::acosh(a); 72 } 73 }; 74 template <typename T> 75 struct functor_traits<scalar_acosh_op<T>> { 76 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 77 }; 78 79 template <typename T> 80 struct scalar_atanh_op { 81 EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op) 82 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 83 return std::atanh(a); 84 } 85 }; 86 template <typename T> 87 struct functor_traits<scalar_atanh_op<T>> { 88 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 89 }; 90 #endif 91 92 template <typename Scalar, typename Exponent> 93 struct safe_scalar_binary_pow_op { 94 static_assert(std::is_integral<Scalar>::value, "Integer type expected"); 95 static_assert(std::is_integral<Exponent>::value && 96 std::is_signed<Exponent>::value, 97 "Signed integer type expected"); 98 99 bool* const error; 100 101 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) 102 : error(error) {} 103 104 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 105 const Exponent& b) const { 106 const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); 107 if (TF_PREDICT_TRUE(safe_b >= 0)) { 108 return numext::pow(a, safe_b); 109 } else { 110 *error = true; 111 return 0; 112 } 113 } 114 }; 115 116 template <typename Scalar, typename Exponent> 117 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { 118 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 119 }; 120 121 template <typename T, typename DivOrMod> 122 struct safe_div_or_mod_op { 123 static_assert(std::is_integral<T>::value, "Integer type expected"); 124 125 bool* const error; 126 127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) 128 : error(error) {} 129 130 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, 131 const T& b) const { 132 const T safe_b = tensorflow::internal::SubtleMustCopy(b); 133 if (TF_PREDICT_TRUE(safe_b != 0)) { 134 return DivOrMod()(a, safe_b); 135 } else { 136 *error = true; 137 return 0; 138 } 139 } 140 }; 141 142 template <typename T, typename DivOrMod> 143 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> { 144 enum { 145 Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost, 146 PacketAccess = false, 147 }; 148 }; 149 150 template <typename T, typename Binary> 151 struct no_nan_op { 152 EIGEN_EMPTY_STRUCT_CTOR(no_nan_op) 153 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, 154 const T& b) const { 155 if (b != T(0)) { 156 return Binary()(a, b); 157 } else { 158 return T(0); 159 } 160 } 161 template <typename Packet> 162 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 163 packetOp(const Packet& a, const Packet& b) const { 164 const Packet mask = pcmp_eq(b, pzero(b)); 165 const Packet quotient = Binary().packetOp(a, b); 166 return pandnot(quotient, mask); 167 } 168 }; 169 170 template <typename T> 171 struct div_no_nan_op : public no_nan_op<T, scalar_quotient_op<T>> { 172 EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op) 173 }; 174 175 template <typename T> 176 struct functor_traits<div_no_nan_op<T>> { 177 enum { 178 Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost, 179 PacketAccess = true, 180 }; 181 }; 182 183 template <typename T> 184 struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> { 185 EIGEN_EMPTY_STRUCT_CTOR(mul_no_nan_op) 186 }; 187 188 template <typename T> 189 struct functor_traits<mul_no_nan_op<T>> { 190 enum { 191 Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost, 192 PacketAccess = true, 193 }; 194 }; 195 196 // scalar_left and scalar_right are template helpers to partially 197 // apply a binary function. 198 // 199 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a 200 // unary functor g_x(y) = f(x, y), where x is provided via the 201 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) = 202 // f(x, y). 203 204 template <typename Tout, typename Tin, typename Binary> 205 struct scalar_left : private Binary { 206 typedef Tout result_type; 207 const Tin* left; 208 209 EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default; 210 211 template <typename... Args> 212 EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args) 213 : Binary(args...), left(c) {} 214 215 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { 216 return Binary::operator()(*left, right); 217 } 218 219 template <typename Packet> 220 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 221 const Packet left_packet = Eigen::internal::pset1<Packet>(*left); 222 return Binary::packetOp(left_packet, right_packet); 223 } 224 }; 225 226 template <typename Tout, typename Tin, typename Binary> 227 struct functor_traits<scalar_left<Tout, Tin, Binary>> { 228 enum { 229 Cost = functor_traits<Binary>::Cost, 230 PacketAccess = functor_traits<Binary>::PacketAccess, 231 }; 232 }; 233 234 template <typename Tout, typename Tin, typename Binary> 235 struct scalar_right : private Binary { 236 typedef Tout result_type; 237 const Tin* right; 238 239 EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default; 240 241 template <typename... Args> 242 EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args) 243 : Binary(args...), right(c) {} 244 245 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { 246 return Binary::operator()(left, *right); 247 } 248 249 template <typename Packet> 250 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 251 const Packet right_packet = Eigen::internal::pset1<Packet>(*right); 252 return Binary::packetOp(left_packet, right_packet); 253 } 254 }; 255 256 template <typename Tout, typename Tin, typename Binary> 257 struct functor_traits<scalar_right<Tout, Tin, Binary>> { 258 enum { 259 Cost = functor_traits<Binary>::Cost, 260 PacketAccess = functor_traits<Binary>::PacketAccess, 261 }; 262 }; 263 264 // similar to std::equal_to, but with the DEVICE_FUNC qualifier 265 template <class T> 266 struct equal_to : std::binary_function<T, T, bool> { 267 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 268 const T& y) const { 269 return x == y; 270 } 271 }; 272 273 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier 274 template <class T> 275 struct not_equal_to : std::binary_function<T, T, bool> { 276 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 277 const T& y) const { 278 return x != y; 279 } 280 }; 281 282 // similar to std::greater, but with the DEVICE_FUNC qualifier 283 template <class T> 284 struct greater : std::binary_function<T, T, bool> { 285 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 286 const T& y) const { 287 return x > y; 288 } 289 }; 290 291 // similar to std::less, but with the DEVICE_FUNC qualifier 292 template <class T> 293 struct less : std::binary_function<T, T, bool> { 294 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 295 const T& y) const { 296 return x < y; 297 } 298 }; 299 300 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier 301 template <class T> 302 struct greater_equal : std::binary_function<T, T, bool> { 303 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 304 const T& y) const { 305 return x >= y; 306 } 307 }; 308 309 // similar to std::less_equal, but with the DEVICE_FUNC qualifier 310 template <class T> 311 struct less_equal : std::binary_function<T, T, bool> { 312 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 313 const T& y) const { 314 return x <= y; 315 } 316 }; 317 318 // Functor that enables squared difference functor. 319 template <typename Scalar> 320 struct scalar_squared_difference_op { 321 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 322 operator()(const Scalar& a, const Scalar& b) const { 323 const Scalar v = scalar_difference_op<Scalar>()(a, b); 324 return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v)); 325 } 326 template <typename Packet> 327 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 328 packetOp(const Packet& a, const Packet& b) const { 329 const Packet v = scalar_difference_op<Scalar>().packetOp(a, b); 330 return scalar_product_op<Scalar>().packetOp( 331 v, scalar_conjugate_op<Scalar>().packetOp(v)); 332 } 333 }; 334 335 template <typename Scalar> 336 struct functor_traits<scalar_squared_difference_op<Scalar>> { 337 enum { 338 Cost = functor_traits<scalar_difference_op<Scalar>>::Cost + 339 functor_traits<scalar_conjugate_op<Scalar>>::Cost + 340 functor_traits<scalar_product_op<Scalar>>::Cost, 341 PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess && 342 functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess && 343 functor_traits<scalar_product_op<Scalar>>::PacketAccess 344 }; 345 }; 346 347 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 348 template <typename T, typename Enable = void> 349 struct google_floor_div { 350 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 351 const T& y) const { 352 if ((x < T(0)) != (y < T(0))) { 353 T abs_x = std::abs(x); 354 T abs_y = std::abs(y); 355 return -(abs_x + abs_y - 1) / abs_y; 356 } else { 357 return x / y; 358 } 359 } 360 template <typename Packet> 361 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 362 packetOp(const Packet& x, const Packet& y) const { 363 Packet zeros = pzero(x); 364 Packet x_mask = pcmp_lt(x, zeros); 365 Packet y_mask = pcmp_lt(y, zeros); 366 Packet x_div_y = pdiv(x, y); 367 Packet abs_x = pabs(x); 368 Packet abs_y = pabs(y); 369 Packet ones = pones(x); 370 Packet ratio_rounded = pdiv(pnegate(psub(padd(abs_x, abs_y), ones)), abs_y); 371 return pselect(pxor(x_mask, y_mask), ratio_rounded, x_div_y); 372 } 373 }; 374 375 template <typename T> 376 struct google_floor_div< 377 T, typename std::enable_if<std::is_unsigned<T>::value>::type> { 378 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 379 const T& y) const { 380 return x / y; 381 } 382 template <typename Packet> 383 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 384 packetOp(const Packet& x, const Packet& y) const { 385 return pdiv(x, y); 386 } 387 }; 388 389 template <typename Scalar> 390 struct functor_traits<google_floor_div<Scalar>> { 391 enum { 392 Cost = 2 * Eigen::internal::scalar_div_cost< 393 Scalar, packet_traits<Scalar>::HasDiv>::value + 394 NumTraits<Scalar>::AddCost, 395 PacketAccess = packet_traits<Scalar>::HasDiv 396 }; 397 }; 398 399 template <typename T, typename Enable = void> 400 struct google_floor_div_real { 401 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 402 const T& y) const { 403 return Eigen::numext::floor(x / y); 404 } 405 template <typename Packet> 406 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 407 packetOp(const Packet& x, const Packet& y) const { 408 return pfloor(pdiv(x, y)); 409 } 410 }; 411 412 template <typename Scalar> 413 struct functor_traits<google_floor_div_real<Scalar>> { 414 enum { 415 Cost = 2 * Eigen::internal::scalar_div_cost< 416 Scalar, packet_traits<Scalar>::HasDiv>::value + 417 2 * NumTraits<Scalar>::AddCost, 418 PacketAccess = 419 packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor 420 }; 421 }; 422 423 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. 424 template <typename T> 425 struct google_floor_fmod { 426 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 427 const T& y) const { 428 // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); 429 T trunc_mod = std::fmod(x, y); 430 return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); 431 } 432 }; 433 434 template <typename Scalar> 435 struct functor_traits<google_floor_fmod<Scalar>> { 436 enum { 437 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 438 2 * NumTraits<Scalar>::AddCost, 439 PacketAccess = false 440 }; 441 }; 442 443 // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. 444 template <typename T> 445 struct google_floor_mod { 446 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 447 const T& y) const { 448 // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); 449 T trunc_mod = x % y; 450 return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y; 451 } 452 }; 453 454 template <typename Scalar> 455 struct functor_traits<google_floor_mod<Scalar>> { 456 enum { 457 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 458 2 * NumTraits<Scalar>::AddCost, 459 PacketAccess = false 460 }; 461 }; 462 463 #if EIGEN_COMP_GNUC && __cplusplus > 199711L 464 #define DISABLE_FLOAT_EQUALITY_WARNING \ 465 _Pragma("GCC diagnostic push") \ 466 _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") 467 #define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") 468 #else 469 #define DISABLE_FLOAT_EQUALITY_WARNING 470 #define ENABLE_FLOAT_EQUALITY_WARNING 471 #endif 472 473 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger> 474 struct scalar_round_op_google { 475 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 476 operator()(const Scalar& x) const { 477 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 478 NUMERIC_TYPE_MUST_BE_REAL) 479 480 Scalar round_val = Eigen::numext::floor(x); 481 const Scalar fraction = x - round_val; 482 if (fraction > Scalar(.5)) { 483 round_val += Scalar(1.0); 484 } else if (fraction == Scalar(.5)) { 485 const Scalar nearest_even_int = 486 round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x); 487 bool is_odd = (nearest_even_int == Scalar(1)); 488 if (is_odd) { 489 round_val += Scalar(1); 490 } 491 } 492 return round_val; 493 } 494 }; 495 496 template <typename Scalar> 497 struct scalar_round_op_google<Scalar, true> { 498 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 499 operator()(const Scalar& x) const { 500 return x; 501 } 502 template <typename Packet> 503 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 504 packetOp(const Packet& x) const { 505 return x; 506 } 507 }; 508 509 template <typename Scalar> 510 struct functor_traits<scalar_round_op_google<Scalar>> { 511 enum { 512 Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0 513 : 4 * NumTraits<Scalar>::AddCost, 514 PacketAccess = Eigen::NumTraits<Scalar>::IsInteger 515 }; 516 }; 517 518 template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger> 519 struct scalar_round_up_op { 520 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 521 operator()(const Scalar& x) const { 522 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 523 NUMERIC_TYPE_MUST_BE_REAL) 524 return Eigen::numext::floor(x + Scalar(0.5)); 525 } 526 527 template <typename Packet> 528 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 529 packetOp(const Packet& x) const { 530 return pfloor(padd(x, pset1<Packet>(0.5))); 531 } 532 }; 533 534 template <typename Scalar> 535 struct scalar_round_up_op<Scalar, true> { 536 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 537 operator()(const Scalar& x) const { 538 return x; 539 } 540 541 template <typename Packet> 542 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 543 packetOp(const Packet& x) const { 544 return x; 545 } 546 }; 547 548 template <typename Scalar, bool IsInteger> 549 struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> { 550 enum { 551 Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost, 552 PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor 553 }; 554 }; 555 556 #undef ENABLE_FLOAT_EQUALITY_WARNING 557 #undef DISABLE_FLOAT_EQUALITY_WARNING 558 559 template <typename Scalar> 560 struct bitwise_xor_op { 561 EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op) 562 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 563 operator()(const Scalar& x, const Scalar& y) const { 564 return x ^ y; 565 } 566 template <typename Packet> 567 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 568 packetOp(const Packet& a, const Packet& b) const { 569 return Eigen::internal::pxor(a, b); 570 } 571 }; 572 573 template <typename Scalar> 574 struct functor_traits<bitwise_xor_op<Scalar>> { 575 enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; 576 }; 577 578 template <typename Scalar> 579 struct xlogy_op { 580 EIGEN_EMPTY_STRUCT_CTOR(xlogy_op) 581 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 582 operator()(const Scalar& x, const Scalar& y) const { 583 if (x == Scalar(0.)) { 584 return Scalar(0.); 585 } 586 return x * numext::log(y); 587 } 588 template <typename Packet> 589 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 590 packetOp(const Packet& x, const Packet& y) const { 591 Packet zeros = pzero(x); 592 Packet mask = pcmp_eq(x, zeros); 593 scalar_log_op<Scalar> log_op; 594 Packet log_y = log_op.packetOp(y); 595 Packet x_log_y = pmul(x, log_y); 596 return pselect(mask, x, x_log_y); 597 } 598 }; 599 600 template <typename Scalar> 601 struct functor_traits<xlogy_op<Scalar>> { 602 enum { 603 Cost = functor_traits<scalar_log_op<Scalar>>::Cost + 604 Eigen::NumTraits<Scalar>::MulCost, 605 PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess 606 }; 607 }; 608 609 template <typename Scalar> 610 struct xdivy_op { 611 EIGEN_EMPTY_STRUCT_CTOR(xdivy_op) 612 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 613 operator()(const Scalar& x, const Scalar& y) const { 614 if (x == Scalar(0.)) { 615 return Scalar(0.); 616 } 617 return x / y; 618 } 619 template <typename Packet> 620 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 621 packetOp(const Packet& x, const Packet& y) const { 622 Packet zeros = pzero(x); 623 Packet mask = pcmp_eq(x, zeros); 624 Packet x_div_y = pdiv(x, y); 625 return pselect(mask, x, x_div_y); 626 } 627 }; 628 629 template <typename Scalar> 630 struct functor_traits<xdivy_op<Scalar>> { 631 enum { 632 Cost = 633 Eigen::NumTraits<Scalar>::AddCost + 634 Eigen::internal::scalar_div_cost<Scalar, 635 packet_traits<Scalar>::HasDiv>::value, 636 PacketAccess = packet_traits<Scalar>::HasDiv 637 }; 638 }; 639 640 } // end namespace internal 641 } // end namespace Eigen 642 643 namespace tensorflow { 644 namespace functor { 645 646 //////////////////////////////////////////////////////////////////////////////// 647 // Helpers 648 //////////////////////////////////////////////////////////////////////////////// 649 650 // Base template for functors whose input scalar type is T and 651 // output scalar type is R. 652 template <typename T, typename F, typename R = T> 653 struct base { 654 // func defines operator() and its vectorized version packetOp(). 655 typedef F func; 656 657 // If true, the functor's corresponding binary op will instantiate 658 // specialized kernels to perform an optimized broadcast 659 // operation. Each functor for which this is enabled increases the 660 // code size, so by default this is disabled for binary functors and 661 // is enabled on a per-op basis as needed. 662 static const bool use_bcast_optimization = false; 663 664 // operator() has the signature: 665 // out_type operator()(in_type in0, in_type in1 ...) 666 typedef R out_type; 667 typedef T in_type; 668 669 // TensorFlow provides tensor-ized version of "func". Roughly 670 // speaking, the tensorflow operation has the signature: 671 // tout_type op(tin_type in0) 672 // tout_type op(tin_type in0, tin_type in1) 673 // tout_type op(tin_type in0, in_type scalar) 674 typedef typename TTypes<out_type>::Flat tout_type; 675 typedef typename TTypes<in_type>::ConstFlat tin_type; 676 typedef typename TTypes<in_type>::ConstScalar tscalar_type; 677 678 // Whether the functor can error out. Currently applies only to integer 679 // div and mod. 680 static const bool has_errors = false; 681 }; 682 683 // For now, we only apply certain speed optimization for 684 // float/double's broadcast binary op. 685 template <typename T> 686 struct use_bcast_optimization { 687 static const bool value = false; 688 }; 689 690 template <> 691 struct use_bcast_optimization<float> { 692 static const bool value = true; 693 }; 694 695 template <> 696 struct use_bcast_optimization<double> { 697 static const bool value = true; 698 }; 699 700 //////////////////////////////////////////////////////////////////////////////// 701 // Unary functors 702 //////////////////////////////////////////////////////////////////////////////// 703 704 // abs(x) = |x| 705 // neg(x) = - x 706 // inverse(x) = 1 / x 707 // square(x) = x^2 708 // sqrt(x) = x^(1/2) 709 // rsqrt(x) = x^(-1/2) 710 // exp(x) = e^x 711 // expm1(x) = e^x - 1 712 // log(x) = natural logarithm of x 713 // log1p(x) = natural logarithm of 1 + x 714 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) 715 // sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic 716 // 717 // NOTE: We may eventually implement common functions used in NN 718 // here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. 719 // For reference, see speech/lstm/eigen_functors.h. 720 721 template <typename T> 722 struct abs : base<T, Eigen::internal::scalar_abs_op<T>, 723 typename Eigen::internal::scalar_abs_op<T>::result_type> {}; 724 725 template <typename T> 726 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {}; 727 728 template <typename T> 729 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {}; 730 731 template <typename T> 732 struct square : base<T, Eigen::internal::scalar_square_op<T>> {}; 733 734 template <typename T> 735 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {}; 736 737 template <typename T> 738 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {}; 739 740 template <typename T> 741 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {}; 742 743 template <typename T> 744 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {}; 745 746 template <typename T> 747 struct log : base<T, Eigen::internal::scalar_log_op<T>> {}; 748 749 template <typename T> 750 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {}; 751 752 template <typename T> 753 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {}; 754 755 template <typename T> 756 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {}; 757 758 template <typename T> 759 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {}; 760 761 template <typename T> 762 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {}; 763 764 template <typename T> 765 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {}; 766 767 template <typename T> 768 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {}; 769 770 template <typename T> 771 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {}; 772 773 template <typename T> 774 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {}; 775 776 template <typename T> 777 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {}; 778 779 template <typename T> 780 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {}; 781 782 template <typename T> 783 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {}; 784 785 template <typename T> 786 struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {}; 787 788 template <typename T> 789 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {}; 790 791 template <typename T> 792 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {}; 793 794 template <typename T> 795 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {}; 796 797 template <typename T> 798 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {}; 799 800 template <typename T> 801 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {}; 802 803 template <typename T> 804 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {}; 805 806 template <typename T> 807 struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {}; 808 809 template <typename T> 810 struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {}; 811 812 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> { 813 }; 814 815 // Flip all bits. Named invert to be consistent with numpy. 816 template <typename T> 817 struct invert_op { 818 EIGEN_EMPTY_STRUCT_CTOR(invert_op) 819 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 820 return ~a; 821 } 822 }; 823 824 template <typename T> 825 struct invert : base<T, invert_op<T>> {}; 826 827 // NOTE: std::isinf, std::isnan, std::isfinite are plain function. 828 // Therefore we need to wrap them in functors to be used with Eigen's 829 // type system. 830 template <typename T> 831 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {}; 832 833 template <typename T> 834 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {}; 835 836 template <typename T> 837 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {}; 838 839 template <typename T> 840 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {}; 841 842 template <typename T> 843 struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {}; 844 845 template <typename T> 846 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {}; 847 848 /** this should go in Eigen 849 * \brief Template functor to compute the round to int value of a scalar 850 */ 851 template <typename Scalar> 852 struct scalar_rint_op { 853 EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op) 854 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 855 operator()(const Scalar& a) const { 856 #if defined(__CUDACC__) 857 return ::rint(a); 858 #elif defined(__ANDROID__) 859 return rint(a); 860 #else 861 return std::rint(a); 862 #endif 863 } 864 }; 865 866 template <typename T> 867 struct rint : base<T, scalar_rint_op<T>> {}; 868 869 //////////////////////////////////////////////////////////////////////////////// 870 // Binary functors 871 //////////////////////////////////////////////////////////////////////////////// 872 873 // Binary functors: 874 // 875 // add(x, y) = x + y 876 // sub(x, y) = x - y 877 // mul(x, y) = x * y 878 // div(x, y) = x / y 879 // mod(x, y) = x % y (int32 and int64 only) 880 // fmod(x, y) = fmod(x, y) (float and double only) 881 // pow(x, y) = x ^ y 882 // maximum(x, y) = x > y ? x : y 883 // minimum(x, y) = x < y ? x : y 884 // squared_difference(x, y) = conj(x - y) * (x - y) 885 886 template <typename T> 887 struct add : base<T, Eigen::internal::scalar_sum_op<T>> { 888 static const bool use_bcast_optimization = true; 889 }; 890 891 template <typename T> 892 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> { 893 static const bool use_bcast_optimization = true; 894 }; 895 896 template <typename T> 897 struct mul : base<T, Eigen::internal::scalar_product_op<T>> { 898 static const bool use_bcast_optimization = true; 899 }; 900 901 template <typename T> 902 struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {}; 903 904 template <typename T> 905 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {}; 906 907 template <typename T> 908 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op< 909 T, Eigen::internal::scalar_quotient_op<T>>> { 910 static const bool has_errors = true; 911 }; 912 913 template <typename T> 914 struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {}; 915 916 template <typename T> 917 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {}; 918 919 template <typename T> 920 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {}; 921 922 template <typename T> 923 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op< 924 T, Eigen::internal::scalar_mod2_op<T>>> { 925 static const bool has_errors = true; 926 }; 927 928 template <typename T> 929 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {}; 930 931 template <typename T> 932 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op< 933 T, Eigen::internal::google_floor_mod<T>>> { 934 static const bool has_errors = true; 935 }; 936 937 template <typename T> 938 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {}; 939 940 template <typename T> 941 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op< 942 T, Eigen::internal::google_floor_div<T>>> { 943 static const bool has_errors = true; 944 }; 945 946 template <typename T> 947 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {}; 948 949 template <typename T> 950 struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {}; 951 952 template <typename T> 953 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { 954 static const bool has_errors = true; 955 }; 956 957 template <typename T> 958 struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {}; 959 960 template <typename T> 961 struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {}; 962 963 template <typename T> 964 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; 965 966 template <typename T> 967 struct random_gamma_grad 968 : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {}; 969 970 template <typename T> 971 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; 972 973 template <typename T> 974 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; 975 976 template <typename T> 977 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; 978 979 template <typename Scalar> 980 struct scalar_atan2_op { 981 EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op) 982 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 983 operator()(const Scalar& y, const Scalar& x) const { 984 #if GOOGLE_CUDA 985 return ::atan2(y, x); 986 #else 987 return std::atan2(y, x); 988 #endif 989 } 990 }; 991 992 template <typename T> 993 struct atan2 : base<T, scalar_atan2_op<T>> {}; 994 995 template <typename T> 996 struct squared_difference 997 : base<T, Eigen::internal::scalar_squared_difference_op<T>> {}; 998 999 template <typename T> 1000 struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {}; 1001 1002 template <typename T> 1003 struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {}; 1004 1005 template <typename T> 1006 struct less : base<T, Eigen::internal::less<T>, bool> {}; 1007 1008 template <typename T> 1009 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {}; 1010 1011 template <typename T> 1012 struct greater : base<T, Eigen::internal::greater<T>, bool> {}; 1013 1014 template <typename T> 1015 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {}; 1016 1017 template <typename T> 1018 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {}; 1019 1020 template <typename T> 1021 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {}; 1022 1023 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {}; 1024 1025 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {}; 1026 1027 template <typename T> 1028 struct bitwise_and_op { 1029 EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op) 1030 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 1031 const T& y) const { 1032 return x & y; 1033 } 1034 }; 1035 1036 template <typename T> 1037 struct bitwise_or_op { 1038 EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op) 1039 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 1040 const T& y) const { 1041 return x | y; 1042 } 1043 }; 1044 1045 template <typename T> 1046 struct bitwise_and : base<T, bitwise_and_op<T>> {}; 1047 1048 template <typename T> 1049 struct bitwise_or : base<T, bitwise_or_op<T>> {}; 1050 1051 template <typename T> 1052 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {}; 1053 1054 template <typename T> 1055 struct left_shift_op { 1056 EIGEN_EMPTY_STRUCT_CTOR(left_shift_op) 1057 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 1058 const T& y) const { 1059 // Avoids UB: don't shift by larger than the bitwidth of T, and 1060 // performs left shifts as unsigned shifts. 1061 T y_clamped = y; 1062 if (y_clamped < 0) { 1063 y_clamped = 0; 1064 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 1065 y_clamped = sizeof(T) * CHAR_BIT - 1; 1066 } 1067 using U = typename std::make_unsigned<T>::type; 1068 return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped)); 1069 } 1070 }; 1071 1072 template <typename T> 1073 struct right_shift_op { 1074 EIGEN_EMPTY_STRUCT_CTOR(right_shift_op) 1075 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 1076 const T& y) const { 1077 // Avoids UB: don't shift by larger than the bitwidth of T. 1078 T y_clamped = y; 1079 if (y_clamped < 0) { 1080 y_clamped = 0; 1081 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 1082 y_clamped = sizeof(T) * CHAR_BIT - 1; 1083 } 1084 // Technically right shifts of signed integers are not necessarily 1085 // arithmetic shifts according to the C++ standard. However in practice most 1086 // implementations are arithmetic shifts. If this proves to be a problem in 1087 // practice, we may need to use an alternative implementation. 1088 return x >> y_clamped; 1089 } 1090 }; 1091 1092 template <typename T> 1093 struct left_shift : base<T, left_shift_op<T>> {}; 1094 1095 template <typename T> 1096 struct right_shift : base<T, right_shift_op<T>> {}; 1097 1098 template <typename T> 1099 struct make_complex_func { 1100 typedef std::complex<T> result_type; 1101 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, 1102 T imag) const { 1103 return std::complex<T>(real, imag); 1104 } 1105 }; 1106 1107 template <typename T> 1108 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {}; 1109 1110 template <typename T> 1111 struct get_real 1112 : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {}; 1113 1114 template <typename T> 1115 struct get_imag 1116 : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {}; 1117 1118 template <typename T> 1119 struct get_angle 1120 : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {}; 1121 1122 template <typename T> 1123 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {}; 1124 1125 //////////////////////////////////////////////////////////////////////////////// 1126 // Functors takes 1 or 2 tensors, computes the base functor on 1127 // coefficient of the input tensors and puts the results in the output 1128 // tensor. 1129 //////////////////////////////////////////////////////////////////////////////// 1130 template <typename Device, typename Functor> 1131 struct UnaryFunctor { 1132 // Computes on device "d": out[i] = Functor(in[i]) 1133 void operator()(const Device& d, typename Functor::tout_type out, 1134 typename Functor::tin_type in); 1135 }; 1136 1137 template <typename Device, typename Functor, int NDIMS, 1138 bool has_errors = Functor::has_errors> 1139 struct BinaryFunctor { 1140 // Computes on device "d": out[i] = Functor(in0[i], in1[i]) 1141 void operator()(const Device& d, typename Functor::tout_type out, 1142 typename Functor::tin_type in0, 1143 typename Functor::tin_type in1, bool* error); 1144 1145 // Computes on device "d": out[i] = Functor(scalar[0], in[i]) 1146 void Left(const Device& d, typename Functor::tout_type out, 1147 typename Functor::tscalar_type scalar, 1148 typename Functor::tin_type in, bool* error); 1149 1150 // Computes on device "d": out[i] = Functor(in[i], scalar[0]) 1151 void Right(const Device& d, typename Functor::tout_type out, 1152 typename Functor::tin_type in, 1153 typename Functor::tscalar_type scalar, bool* error); 1154 1155 // Computes on device "d": 1156 // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1)) 1157 // 1158 // TODO(zhifengc): makes BCast a template member function on NDIMS 1159 // instead making BinaryFunctor templates on NDIMS. 1160 void BCast(const Device& d, 1161 typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, 1162 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, 1163 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, 1164 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, 1165 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, 1166 bool* error); 1167 }; 1168 1169 template <typename Device, typename T> 1170 struct ApproximateEqual { 1171 void operator()(const Device& d, typename TTypes<T>::ConstFlat x, 1172 typename TTypes<T>::ConstFlat y, T tolerance, 1173 typename TTypes<bool>::Flat z); 1174 }; 1175 1176 template <int NDIMS> 1177 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) { 1178 for (size_t i = 0; i < a.size(); ++i) { 1179 if (a[i] != 1) return false; 1180 } 1181 return true; 1182 } 1183 1184 template <typename Device, typename T> 1185 struct SelectFunctor { 1186 void operator()(const Device& d, typename TTypes<T>::Flat out, 1187 typename TTypes<bool>::ConstFlat cond_flat, 1188 typename TTypes<T>::ConstFlat then_flat, 1189 typename TTypes<T>::ConstFlat else_flat); 1190 }; 1191 1192 template <typename Device, typename T> 1193 struct SelectScalarFunctor { 1194 void operator()(const Device& d, typename TTypes<T>::Flat out, 1195 typename TTypes<bool>::ConstScalar cond, 1196 typename TTypes<T>::ConstFlat then_flat, 1197 typename TTypes<T>::ConstFlat else_flat); 1198 }; 1199 1200 template <typename Device, typename T> 1201 struct BatchSelectFunctor { 1202 void operator()(const Device& d, 1203 typename TTypes<T>::Matrix output_flat_outer_dims, 1204 TTypes<bool>::ConstVec cond_vec, 1205 typename TTypes<T>::ConstMatrix then_flat_outer_dims, 1206 typename TTypes<T>::ConstMatrix else_flat_outer_dims); 1207 }; 1208 1209 } // end namespace functor 1210 } // end namespace tensorflow 1211 1212 #endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ 1213