/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/cub/device/device_reduce.cuh"
#include "third_party/cub/device/device_segmented_reduce.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/transform_output_iterator.h"

#include <sstream>

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct Sqrt {
  __host__ __device__ T operator()(const T& a) const {
    return Eigen::numext::sqrt(a);
  }
};

template <typename T>
struct Sum {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

// needed to work around a compiler bug in nvcc - it doesn't seem to like
// the overloaded addition op for std::complex
template <>
struct Sum<std::complex<float>> {
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& a, const std::complex<float>& b) const {
    auto result = cuCaddf(make_cuComplex(a.real(), a.imag()),
                          make_cuComplex(b.real(), b.imag()));
    return std::complex<float>(result.x, result.y);
  }
};

template <>
struct Sum<std::complex<double>> {
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& a, const std::complex<double>& b) const {
    auto result = cuCadd(make_cuDoubleComplex(a.real(), a.imag()),
                         make_cuDoubleComplex(b.real(), b.imag()));
    return std::complex<double>(result.x, result.y);
  }
};

template <typename T>
struct Prod {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

// needed to work around a compiler bug in nvcc - it doesn't seem to like
// the overloaded multiply op for std::complex
template <>
struct Prod<std::complex<float>> {
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& a, const std::complex<float>& b) const {
    auto result = cuCmulf(make_cuComplex(a.real(), a.imag()),
                          make_cuComplex(b.real(), b.imag()));
    return std::complex<float>(result.x, result.y);
  }
};

template <>
struct Prod<std::complex<double>> {
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& a, const std::complex<double>& b) const {
    auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()),
                         make_cuDoubleComplex(b.real(), b.imag()));
    return std::complex<double>(result.x, result.y);
  }
};

// Computes a * conj(a): the squared magnitude for complex types and the plain
// square for real types.
template <typename T>
struct Square {
  __host__ __device__ T operator()(const T& a) const {
    return Prod<T>()(a, Eigen::numext::conj(a));
  }
};

template <typename T, typename outT = T>
struct DividesBy {
  T divisor;

  __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {}

  __host__ __device__ outT operator()(const T& x) const { return x / divisor; }
};

// needed to work around a compiler bug in nvcc - it doesn't seem to like
// the overloaded ops for std::complex
template <>
struct DividesBy<std::complex<float>> {
  cuFloatComplex divisor;

  __host__ __device__ explicit DividesBy(std::complex<float> divisor)
      : divisor(make_cuComplex(divisor.real(), divisor.imag())) {}

  // implements x / divisor using cuComplex arithmetic
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& x) const {
    auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor);
    return std::complex<float>(result.x, result.y);
  }
};

template <>
struct DividesBy<std::complex<double>> {
  cuDoubleComplex divisor;

  __host__ __device__ explicit DividesBy(std::complex<double> divisor)
      : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {}

  // implements x / divisor using cuComplex arithmetic
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& x) const {
    auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor);
    return std::complex<double>(result.x, result.y);
  }
};

template <>
struct DividesBy<float, Eigen::half> {
  float divisor;

  __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {}

  __host__ __device__ Eigen::half operator()(const float& x) const {
    return Eigen::half(x / divisor);
  }
};

struct HalfToFloat {
  __host__ __device__ float operator()(const Eigen::half& x) const {
    return Eigen::half_impl::half_to_float(x);
  }
};

struct FloatToHalf {
  __host__ __device__ Eigen::half operator()(const float& x) const {
    return Eigen::half_impl::float_to_half_rtne(x);
  }
};

struct And {
  __host__ __device__ bool operator()(const bool& a, const bool& b) const {
    return a && b;
  }
};

struct Or {
  __host__ __device__ bool operator()(const bool& a, const bool& b) const {
    return a || b;
  }
};
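
// Note: these functors are usually not called directly but composed with
// cub's "fancy" iterators. For example, the Euclidean-norm reducer below
// wraps the input in cub::TransformInputIterator<T, Square<T>, T*> and the
// output in TransformOutputIterator<T, T, Sqrt<T>>, so that a plain Sum<T>
// reduction computes sqrt(sum_i x_i * conj(x_i)).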

// each block does a grid strided loop and reduces its values locally
// the case of one block is used for low latency small reductions to scalars
template <typename T, typename outT, int num_threads, typename Op>
__global__ void BlockReduceKernel(
    T in, outT out, int num_elems, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;

  const int gid = bid * blockDim.x + tid;
  const int stride = blockDim.x * gridDim.x;

  typedef typename std::iterator_traits<T>::value_type value_type;

  value_type sum = initVal;
  if (gid < num_elems) {
    sum = in[gid];
    for (int pos = gid + stride; pos < num_elems; pos += stride) {
      sum = op(sum, in[pos]);
    }
  }

  typedef cub::BlockReduce<value_type, num_threads> BlockReduce;

  __shared__ typename BlockReduce::TempStorage temp_storage;

  // only include input values in the reduction
  //
  // elements: -----------------
  // grid:     |====|====|====|====|====|
  const int num_elements_to_reduce =
      max(min(num_elems - bid * blockDim.x, num_threads), 0);

  sum = BlockReduce(temp_storage).Reduce(sum, op, num_elements_to_reduce);

  if (tid == 0) out[bid] = sum;
}
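
// Note: BlockReduceKernel writes one partial result per block. In
// LaunchScalarReduction below it is either launched with a single block
// (so out is the final scalar) or followed by a CleanupSegments launch that
// folds the per-block partials into the final value.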

// maps a warp to each row
template <typename T, typename outT, typename Op>
__global__ void RowReduceKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  // Defensive index computation to avoid integer overflow.
  assert(blockDim.x % 32 == 0);
  int warps_per_block = blockDim.x / 32;
  int warp_index = threadIdx.x / 32;
  const int row = blockIdx.x * warps_per_block + warp_index;
  const int lane = threadIdx.x % 32;

  if (num_cols == 1) {
    int gid = threadIdx.x + blockIdx.x * blockDim.x;
    if (gid < num_rows) out[gid] = in[gid];
    return;
  }

  value_type sum = initVal;
  int col = lane;

  if (row < num_rows && col < num_cols) {
    sum = in[row * num_cols + col];
    col += 32;
    for (; col < num_cols; col += 32) {
      sum = op(sum, in[row * num_cols + col]);
    }
  }

  typedef cub::WarpReduce<value_type> WarpReduce;

  __shared__ typename WarpReduce::TempStorage temp_storage;

  sum = WarpReduce(temp_storage).Reduce(sum, op, min(num_cols, 32));

  if (row < num_rows && lane == 0) out[row] = sum;
}
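
// Example for RowReduceKernel above: with num_cols == 100, lane 0 of the
// row's warp accumulates columns 0, 32, 64 and 96, lane 1 accumulates
// 1, 33, 65 and 97, and so on; the cub::WarpReduce then combines the 32
// per-lane partials into out[row].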

template <typename T1>
struct storage_type {
  T1 val;
  __host__ __device__ storage_type() {}
  __host__ __device__ operator T1() { return val; }
  __host__ __device__ storage_type<T1>& operator=(const T1& in) {
    val = in;
    return *this;
  }
};

template <typename T2>
struct storage_type<std::complex<T2>> {
  T2 real;
  T2 imag;
  __host__ __device__ storage_type() {}
  __host__ __device__ operator std::complex<T2>() {
    return std::complex<T2>(real, imag);
  }
  __host__ __device__ storage_type<std::complex<T2>>& operator=(
      const std::complex<T2>& in) {
    real = in.real();
    imag = in.imag();
    return *this;
  }
};

// Works only if there are <= 16 columns
// each warp sums over multiple rows at once
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceMax16ColumnsKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  int rows_per_warp = 32 / num_cols;

  const int lane = threadIdx.x % 32;
  const int lane_row = lane / num_cols;

  const int start_row_warp =
      rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y);
  const int start_row_lane = start_row_warp + lane_row;
  int row = start_row_lane;
  int col = lane % num_cols;

  value_type sum = initVal;
  if (row * num_cols + col < num_rows * num_cols)
    sum = in[row * num_cols + col];

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
  // This is to mimic the following, but without any constructors:
  //   __shared__ storage_type<value_type> partial_sums[32 * 33];
  __shared__ __align__(
      alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
  value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);

  row += rows_per_warp * gridDim.y * blockDim.y;
  for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
    int global_pos = row * num_cols + col;
    if (global_pos < (num_rows * num_cols))
      sum = op(sum, in[row * num_cols + col]);
  }

  const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
  // not the most efficient way to do this sum
  for (int i = 1; i < rows_in_this_warp; ++i) {
    value_type tmp = cub::ShuffleIndex<32, value_type>(
        sum, static_cast<int>(threadIdx.x + i * num_cols), 0xffffffff);
    if (lane < num_cols) sum = op(sum, tmp);
  }

  if (lane < num_cols) partial_sums[lane * 33 + threadIdx.y] = sum;

  __syncthreads();

  if (threadIdx.y == 0 && threadIdx.x < num_cols) {
    value_type s = partial_sums[threadIdx.x * 33];

    if (blockDim.y > 1) {
      for (int row = 1; row < blockDim.y; ++row) {
        value_type t = partial_sums[threadIdx.x * 33 + row];
        s = op(s, t);
      }
    }

    out[col * gridDim.y + blockIdx.y] = s;
  }
}
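
// Note: when gridDim.y > 1, out points at a scratch buffer in which each
// column owns gridDim.y contiguous partial results
// (out[col * gridDim.y + blockIdx.y]); CleanupSegments is then launched with
// segment_size == gridDim.y to fold each column's partials into the final
// value (see LaunchColumnReduction_LTE16Cols below).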

// Maps each block to a column range 32 wide
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * 32 + threadIdx.x;

  value_type sum = initVal;
  if (row < num_rows && col < num_cols) sum = in[row * num_cols + col];

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
  // This is to mimic the following, but without constructors:
  //   __shared__ storage_type<value_type> partial_sums[32 * 33];
  __shared__ __align__(
      alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
  value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);

  row += gridDim.y * blockDim.y;

  if (col < num_cols) {
    for (; row < num_rows; row += gridDim.y * blockDim.y) {
      sum = op(sum, in[row * num_cols + col]);
    }
  }

  partial_sums[threadIdx.x * 33 + threadIdx.y] = sum;

  __syncthreads();

  if (threadIdx.y == 0 && col < num_cols) {
    value_type s = partial_sums[threadIdx.x * 33];

    // only include input values in the reduction
    //
    // elem   block_rows
    //  -         =
    //  -         =
    //  #         #   block boundary
    //  -         =
    //  -         =
    //  #         #   block boundary
    //  -         =
    //            =
    const int numRowsThisBlock =
        min(blockDim.y, num_rows - blockIdx.y * blockDim.y);

    for (int row = 1; row < numRowsThisBlock; ++row) {
      value_type t = partial_sums[threadIdx.x * 33 + row];
      s = op(s, t);
    }

    out[col * gridDim.y + blockIdx.y] = s;
  }
}
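
// Note: partial_sums in the two column-reduce kernels above uses a stride of
// 33 rather than 32; the extra padding element is presumably the usual "+1"
// trick that keeps the column-wise accesses in the final accumulation on
// distinct shared-memory banks and so avoids bank conflicts.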

// does multiple warp size segmented reductions in parallel
// segments cannot cross warp boundaries (mainly used for reducing the segments
// that come from the Max16Columns column reduction kernel)
template <typename T, typename outT, typename Op>
__global__ void CleanupSegments(
    T partial_sums, outT out, int num_rows, int num_cols, int segment_size,
    Op op, typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  value_type val = initVal;
  if (tid < segment_size * num_cols) val = partial_sums[tid];

  typedef cub::WarpReduce<value_type> WarpReduce;

  __shared__ typename WarpReduce::TempStorage temp_storage;

  const bool head_flag = (threadIdx.x % segment_size) == 0;
  value_type sum =
      WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op);

  if (head_flag && tid < segment_size * num_cols) {
    out[tid / segment_size] = sum;
  }
}

// assigns one thread to a column
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes,
                                         int num_rows, int num_cols, Op op) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int elems_per_plane = num_rows * num_cols;

  const int plane = gid / num_cols;
  const int col = gid % num_cols;

  if (plane >= num_planes) return;

  if (num_rows == 1) {
    out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col];
    return;
  }

  value_type sum = op(in[plane * elems_per_plane + col],
                      in[plane * elems_per_plane + num_cols + col]);
  for (int row = 2; row < num_rows; ++row) {
    sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]);
  }

  out[plane * num_cols + col] = sum;
}

struct RowOffset {
  __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}

  __host__ __device__ int operator()(const int& x) const { return cols_ * x; }

  int cols_;
};

struct GatherOp {
  __host__ __device__ GatherOp(const int& extent_x, const int& extent_y,
                               const int& extent_z, bool kOne)
      : extent_x_(extent_x),
        extent_y_(extent_y),
        extent_z_(extent_z),
        kOne_(kOne) {
    if (kOne_)
      group_size_ = extent_y_;
    else
      group_size_ = extent_x_ * extent_z_;
  }

  __host__ __device__ int operator()(const int& ind) const {
    const int group = kOne_ ? ind / group_size_ : ind % group_size_;
    const int offset = kOne_ ? ind % group_size_ : ind / group_size_;

    const int x = group / extent_z_;
    const int z = group % extent_z_;

    return x * extent_y_ * extent_z_ + z + offset * extent_z_;
  }

  int extent_x_;
  int extent_y_;
  int extent_z_;
  bool kOne_;
  int group_size_;
};
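
// Example for GatherOp with kOne == false (as used by Launch3DXZReduction
// below): for a row-major (x, y, z) tensor, consecutive indices ind enumerate,
// for each y in turn, all (x, z) pairs, so
//   ind = y * extent_x * extent_z + x * extent_z + z
// maps to the input offset
//   x * extent_y * extent_z + y * extent_z + z.
// A segmented reduce over segments of length extent_x * extent_z therefore
// reduces away the x and z axes.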

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
                           int in_size, Op op, T init,
                           const cudaStream_t& cu_stream) {
  // handle situations where low latency is important better than CUB does
  if (in_size <= 4096) {
    const int num_blocks = 1;
    const int num_threads = 256;
    TF_CHECK_OK(CudaLaunchKernel(
        BlockReduceKernel<IN_T, OUT_T, num_threads, Op>, num_blocks,
        num_threads, 0, cu_stream, in, out, in_size, op, init));
    return;
  } else if (in_size <= 1 << 18) {
    const int num_threads = 256;
    const int num_blocks = std::min(32, Eigen::divup(in_size, num_threads));
    // it seems like tailoring this to the GPU
    // would be more effective, but all attempts
    // at making this a multiple of the number of
    // multiprocessors have led to lower perf
    // in general
    // TODO(eriche) investigate this more

    Tensor temp_storage;
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_temp(
            DT_INT8, TensorShape({static_cast<int64>(num_blocks * sizeof(T))}),
            &temp_storage));

    TF_CHECK_OK(CudaLaunchKernel(BlockReduceKernel<IN_T, T*, num_threads, Op>,
                                 num_blocks, num_threads, 0, cu_stream, in,
                                 (T*)temp_storage.flat<int8_t>().data(),
                                 in_size, op, init));

    // take care that we only reduce blocks that had some valid elements in
    // them
    // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that
    // requires it to be used with a full warp. Can reduce 32 -> num_blocks
    // when this is fixed.
    TF_CHECK_OK(CudaLaunchKernel(CleanupSegments<T*, OUT_T, Op>, 1, 32, 0,
                                 cu_stream,
                                 (T*)temp_storage.flat<int8_t>().data(), out, 1,
                                 1, num_blocks, op, init));
    return;
  }

  size_t temp_storage_bytes = 0;
  auto reduce = [&](void* temp_storage_ptr) {
    auto success =
        cub::DeviceReduce::Reduce(temp_storage_ptr, temp_storage_bytes, in, out,
                                  in_size, op, init, cu_stream);

    OP_REQUIRES(
        ctx, success == 0,
        errors::Internal("CUB reduce error ", cudaGetErrorString(success)));
  };

  reduce(nullptr);  // Get required amount of temp storage.

  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
                        int num_cols, Op op, T init,
                        const cudaStream_t& cu_stream) {
  if (num_cols < 1024) {
    const int threads_per_block = 128;
    const int warps_per_block = threads_per_block / 32;
    int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block;

    TF_CHECK_OK(CudaLaunchKernel(RowReduceKernel<IN_T, OUT_T, Op>, num_blocks,
                                 threads_per_block, 0, cu_stream, in, out,
                                 num_rows, num_cols, op, init));
    return;
  }

  // setup segment offsets with counting and transform iterator
  RowOffset row_offset_op(num_cols);
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      transform_iter(counting_iter, row_offset_op);

  size_t temp_storage_bytes = 0;
  auto reduce = [&](void* temp_storage_ptr) {
    auto success = cub::DeviceSegmentedReduce::Reduce(
        temp_storage_ptr, temp_storage_bytes, in, out, num_rows, transform_iter,
        transform_iter + 1, op, init, cu_stream);

    OP_REQUIRES(ctx, success == 0,
                errors::Internal("CUB segmented reduce error",
                                 cudaGetErrorString(success)));
  };

  reduce(nullptr);  // Get required amount of temp storage.

  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
}
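
// Note: in the wide-row path of LaunchRowReduction above, segment i of the
// cub segmented reduce spans element offsets [i * num_cols, (i + 1) * num_cols):
// transform_iter supplies the begin offsets and transform_iter + 1 the end
// offsets, so each row is reduced as one segment.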

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
                                     int extent_x, int extent_y, Op op, T init,
                                     const cudaStream_t& cu_stream) {
  int rows_per_warp = 32 / extent_y;
  dim3 block_dim(32, std::min(Eigen::divup(extent_x, rows_per_warp), 32), 1);
  dim3 grid_dim(1,
                Eigen::divup(static_cast<unsigned int>(extent_x),
                             rows_per_warp * block_dim.y),
                1);

  grid_dim.y = std::min((int)grid_dim.y, 32);

  if (grid_dim.y > 2 && grid_dim.y < 32) {
    int log2 = Log2Floor(grid_dim.y);
    grid_dim.y = 1 << log2;
  }

  if (grid_dim.y == 1) {
    TF_CHECK_OK(CudaLaunchKernel(
        ColumnReduceMax16ColumnsKernel<IN_T, OUT_T, Op>, grid_dim, block_dim, 0,
        cu_stream, in, out, extent_x, extent_y, op, init));
  } else {
    Tensor temp_storage;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_INT8,
                                      TensorShape({static_cast<int64>(
                                          sizeof(T) * extent_y * grid_dim.y)}),
                                      &temp_storage));
    TF_CHECK_OK(CudaLaunchKernel(ColumnReduceMax16ColumnsKernel<IN_T, T*, Op>,
                                 grid_dim, block_dim, 0, cu_stream, in,
                                 (T*)temp_storage.flat<int8_t>().data(),
                                 extent_x, extent_y, op, init));

    dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    dim3 num_threads(128, 1, 1);
    TF_CHECK_OK(CudaLaunchKernel(CleanupSegments<T*, OUT_T, Op>, new_grid_dim,
                                 num_threads, 0, cu_stream,
                                 (T*)temp_storage.flat<int8_t>().data(), out,
                                 extent_x, extent_y, grid_dim.y, op, init));
  }
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
                                       int extent_x, int extent_y, Op op,
                                       T init, const cudaStream_t& cu_stream) {
  dim3 block_dim(32, std::min(extent_x, 32), 1);
  dim3 grid_dim((extent_y + 31) / 32, 1, 1);

  if (grid_dim.x < 16) grid_dim.y = std::min((extent_x + 31) / 32, 32);

  if (grid_dim.y > 2 && grid_dim.y < 32) {
    int log2 = Log2Floor(grid_dim.y);
    grid_dim.y = 1 << log2;
  }

  if (grid_dim.y == 1) {
    TF_CHECK_OK(CudaLaunchKernel(ColumnReduceKernel<IN_T, OUT_T, Op>, grid_dim,
                                 block_dim, 0, cu_stream, in, out, extent_x,
                                 extent_y, op, init));
  } else {
    Tensor temp_storage;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_INT8,
                                      TensorShape({static_cast<int64>(
                                          sizeof(T) * extent_y * grid_dim.y)}),
                                      &temp_storage));

    TF_CHECK_OK(CudaLaunchKernel(
        ColumnReduceKernel<IN_T, T*, Op>, grid_dim, block_dim, 0, cu_stream, in,
        (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op, init));

    dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    dim3 num_threads(128, 1, 1);
    TF_CHECK_OK(CudaLaunchKernel(CleanupSegments<T*, OUT_T, Op>, new_grid_dim,
                                 num_threads, 0, cu_stream,
                                 (T*)temp_storage.flat<int8_t>().data(), out,
                                 extent_x, extent_y, grid_dim.y, op, init));
  }
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
                           int extent_x, int extent_y, Op op, T init,
                           const cudaStream_t& cu_stream) {
  if (extent_y <= 16) {
    LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init,
                                    cu_stream);
  } else if (extent_y <= 4096) {
    LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op,
                                      init, cu_stream);
  } else {
    int threads_per_block = 128;
    int num_blocks = Eigen::divup(extent_y, threads_per_block);

    TF_CHECK_OK(CudaLaunchKernel(ColumnReduceSimpleKernel<IN_T, OUT_T, Op>,
                                 num_blocks, threads_per_block, 0, cu_stream,
                                 in, out, 1, extent_x, extent_y, op));
  }
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                        int extent_y, int extent_z, Op op, T init,
                        const cudaStream_t& cu_stream) {
  int threads_per_block = 128;
  int num_blocks =
      (extent_x * extent_z + threads_per_block - 1) / threads_per_block;

  // TODO(eriche): this won't be very good in the case of small x
  // small z and large y.
  TF_CHECK_OK(CudaLaunchKernel(ColumnReduceSimpleKernel<IN_T, OUT_T, Op>,
                               num_blocks, threads_per_block, 0, cu_stream, in,
                               out, extent_x, extent_y, extent_z, op));
}
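
// Note: Launch3DYReduction above reuses ColumnReduceSimpleKernel by viewing
// the (x, y, z) input as extent_x planes of an extent_y x extent_z matrix and
// assigning one thread per (plane, z) output element; the > 4096-column
// fallback in LaunchColumnReduction does the same with a single plane.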

template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                         int extent_y, int extent_z, Op op, T init,
                         const cudaStream_t& cu_stream) {
  // setup segment offsets with counting and transform iterator
  RowOffset row_offset_op(extent_x * extent_z);
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      transform_iter(counting_iter, row_offset_op);

  GatherOp gather_op(extent_x, extent_y, extent_z, false);
  typedef cub::TransformInputIterator<int, GatherOp,
                                      cub::CountingInputIterator<int>>
      gatherIterType;
  gatherIterType gather_iter(counting_iter, gather_op);

  PermutationInputIterator<T, IN_T, gatherIterType> permute_iter(in,
                                                                 gather_iter);

  std::size_t temp_storage_bytes = 0;
  auto reduce = [&](void* temp_storage_ptr) {
    auto success = cub::DeviceSegmentedReduce::Reduce(
        temp_storage_ptr, temp_storage_bytes, permute_iter, out, extent_y,
        transform_iter, transform_iter + 1, op, init, cu_stream);

    OP_REQUIRES(ctx, success == 0,
                errors::Internal("CUB segmented reduce error",
                                 cudaGetErrorString(success)));
  };

  reduce(nullptr);  // Get required amount of temp storage.

  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
}

namespace reduction_op_helper {

template <typename T, typename Op>
struct IsSum {
  constexpr static bool value =
      (std::is_same<Op, cub::Sum>::value ||
       std::is_same<Op, Eigen::internal::SumReducer<T>>::value ||
       std::is_same<Op, Sum<T>>::value);
};

template <typename T, typename Op>
struct IsMax {
  constexpr static bool value =
      (std::is_same<Op, cub::Max>::value ||
       std::is_same<Op, Eigen::internal::MaxReducer<T>>::value);
};

template <typename T, typename Op>
struct IsMin {
  constexpr static bool value =
      (std::is_same<Op, cub::Min>::value ||
       std::is_same<Op, Eigen::internal::MinReducer<T>>::value);
};

template <typename T, typename Op>
struct IsProd {
  constexpr static bool value =
      (std::is_same<Op, Prod<T>>::value ||
       std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
};

template <typename T, typename Op>
struct IdentityValue {
  static_assert(IsSum<T, Op>::value || IsMax<T, Op>::value ||
                    IsMin<T, Op>::value || IsProd<T, Op>::value ||
                    std::is_same<Op, And>::value || std::is_same<Op, Or>::value,
                "IdentityValue not yet defined for this type");

  template <typename U = T, typename OpCopy = Op>
  U operator()(
      typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<IsMax<U, OpCopy>::value, U>::type t =
                   Eigen::NumTraits<U>::lowest()) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<IsMin<U, OpCopy>::value, U>::type t =
                   Eigen::NumTraits<U>::highest()) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(
      typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<std::is_same<OpCopy, And>::value,
                                       bool>::type t = true) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<std::is_same<OpCopy, Or>::value,
                                       bool>::type t = false) {
    return t;
  }
};

}  // namespace reduction_op_helper
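
// For example, ReduceImpl below seeds CUB with
// reduction_op_helper::IdentityValue<T, Op>()(): IdentityValue<float,
// cub::Sum>()() yields 0.0f, while IdentityValue<float, cub::Max>()() yields
// Eigen::NumTraits<float>::lowest(). The enable_if'd default arguments select
// the identity element for the given Op.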

template <typename T, typename Op, typename OUT_T, typename IN_T,
          typename ReductionAxes>
void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
                int in_dim0, int in_dim1, int in_dim2, int out_rank,
                const ReductionAxes& reduction_axes, Op op) {
  T init = reduction_op_helper::IdentityValue<T, Op>()();
  const cudaStream_t& cu_stream = GetCudaStream(ctx);
  if (out_rank == 0) {
    const int in_size = in_dim0 * in_dim1 * in_dim2;
    LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream);
  } else if (in_rank == 2 && out_rank == 1 &&
             reduction_axes[0] == 1) {  // row reduction
    LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
  } else if (in_rank == 2 && out_rank == 1 &&
             reduction_axes[0] == 0) {  // column reduction
    LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
  } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
    Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
                       cu_stream);
  } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2) {
    Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
                        cu_stream);
  } else {
    std::stringstream ss;
    ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank
       << " " << out_rank;
    if (out_rank == 1) ss << " " << reduction_axes[0];
    if (out_rank == 2) ss << " " << reduction_axes[1];
    LOG(FATAL) << ss.str();
  }
}
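
// Note: callers of ReduceImpl (the ReduceFunctor specializations below) treat
// the input as at most three-dimensional and pass 1 for dimensions beyond
// in_rank, e.g. in_dim1 = in.rank() >= 2 ? in.dimension(1) : 1.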

template <typename Reducer>
struct ReduceFunctor<GPUDevice, Reducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::SumReducer<T>& reducer) {
    ReduceImpl<T, Sum<T>, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        Sum<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::SumReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

// TODO(rmlarsen): Specialize for float16.
template <typename T>
struct ReduceFunctor<GPUDevice, functor::EuclideanNormReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const functor::EuclideanNormReducer<T>& reducer) {
    typedef cub::TransformInputIterator<T, Square<T>, T*> inputIterType;
    inputIterType input_itr((T*)in.data(), Square<T>());
    typedef TransformOutputIterator<T, T, Sqrt<T>> outputIterType;
    outputIterType output_itr((T*)out.data(), Sqrt<T>());
    ReduceImpl<T, Sum<T>, outputIterType, inputIterType, ReductionAxes>(
        ctx, output_itr, input_itr, in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        Sum<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const functor::EuclideanNormReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <typename T>
struct ReduceFunctor<GPUDevice, functor::MeanReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const functor::MeanReducer<T>& reducer) {
    int divisor = 1;
    if (out.rank() == 0)
      divisor = in.size();
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
      divisor = in.dimension(0);
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
      divisor = in.dimension(1);
    else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2)
      divisor = in.dimension(0) * in.dimension(2);
    else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
      divisor = in.dimension(1);

    DividesBy<T> div_op(static_cast<T>(divisor));
    TransformOutputIterator<T, T, DividesBy<T>> itr((T*)out.data(), div_op);
    ReduceImpl<T, Sum<T>, TransformOutputIterator<T, T, DividesBy<T>>, T*,
               ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(),
                              in.dimension(0),
                              in.rank() >= 2 ? in.dimension(1) : 1,
                              in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
                              reduction_axes, Sum<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const functor::MeanReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <>
struct ReduceFunctor<GPUDevice, functor::MeanReducer<Eigen::half>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const functor::MeanReducer<Eigen::half>& reducer) {
    float divisor = 1.f;
    if (out.rank() == 0)
      divisor = in.size();
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
      divisor = in.dimension(0);
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
      divisor = in.dimension(1);
    else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2)
      divisor = in.dimension(0) * in.dimension(2);
    else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
      divisor = in.dimension(1);
    DividesBy<float, Eigen::half> div_op(divisor);

    typedef cub::TransformInputIterator<float, HalfToFloat, Eigen::half*>
        inputIterType;
    inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat());

    typedef TransformOutputIterator<Eigen::half, float,
                                    DividesBy<float, Eigen::half>>
        outputIterType;
    outputIterType itr((Eigen::half*)out.data(), div_op);

    ReduceImpl<float, cub::Sum, outputIterType, inputIterType, ReductionAxes>(
        ctx, itr, input_itr, in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        cub::Sum());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const functor::MeanReducer<Eigen::half>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};
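
// Note: the Eigen::half mean above accumulates in float (HalfToFloat on the
// input iterator, cub::Sum as the reduction op) and only converts back to
// half when DividesBy<float, Eigen::half> writes the output, avoiding the
// precision loss of repeated half-precision additions.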

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MaxReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::MaxReducer<T>& reducer) {
    ReduceImpl<T, cub::Max, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        cub::Max());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::MaxReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MinReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::MinReducer<T>& reducer) {
    ReduceImpl<T, cub::Min, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        cub::Min());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::MinReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::ProdReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::ProdReducer<T>& reducer) {
    ReduceImpl<T, Prod<T>, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        Prod<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::ProdReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <>
struct ReduceFunctor<GPUDevice, Eigen::internal::AndReducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::AndReducer& reducer) {
    ReduceImpl<bool, And, bool*, bool*, ReductionAxes>(
        ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        And());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::AndReducer& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <>
struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::OrReducer& reducer) {
    ReduceImpl<bool, Or, bool*, bool*, ReductionAxes>(
        ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::OrReducer& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_