• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA
17 
18 #define EIGEN_USE_GPU
19 
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/kernels/gpu_device_array_gpu.h"
23 #include "tensorflow/core/util/cuda_kernel_helper.h"
24 
25 namespace tensorflow {
26 
27 using GPUDevice = Eigen::GpuDevice;
28 
29 namespace {
30 
// Fills `output` slice by slice from the stitch sources.
//
// `output` is addressed as a flat array of `output_size` elements, grouped into
// consecutive slices of `slice_size` elements. For output slice `s`, the value
// `input_indices[s]` selects an entry of `input_ptrs`; that entry points at the
// source slice whose elements are copied (read through `ldg`) into slice `s`.
// A selector of -1 means slice `s` is not written at all.
//
// Launch: 1-D; any grid/block shape works because the element loop is driven
// by CUDA_1D_KERNEL_LOOP. No shared memory is used.
template <typename T>
__global__ void DynamicStitchKernel(const int32 slice_size,
                                    const int32 output_size,
                                    GpuDeviceArrayStruct<int32> input_indices,
                                    GpuDeviceArrayStruct<const T*> input_ptrs,
                                    T* output) {
  // Device-side views of the per-slice selector table and the source pointers.
  const int32* const slice_sources = GetGpuDeviceArrayOnDevice(&input_indices);
  const T** const source_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs);

  CUDA_1D_KERNEL_LOOP(elem, output_size) {
    const int32 source = slice_sources[elem / slice_size];
    // -1 marks an output slice with no source; leave it untouched.
    if (source == -1) continue;
    output[elem] = ldg(source_ptrs[source] + elem % slice_size);
  }
}
48 
49 }  // namespace
50 
// Host-side launcher for DynamicStitchKernel<T>.
//
// Covers `first_dim_size * slice_size` elements of `output`, viewed as
// `first_dim_size` slices of `slice_size` elements. Output slice `s` is copied
// from the source slice `input_ptrs[input_indices[s]]`; slices whose selector
// is -1 are not written.
//
// The kernel is enqueued on `gpu_device.stream()`; this function performs no
// synchronization. A failed launch aborts the process via TF_CHECK_OK.
template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const GpuDeviceArrayStruct<int32>& input_indices,
                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                          T* output) {
  // NOTE(review): the product is formed in int32; callers presumably
  // guarantee it fits (a positive wrap-around would silently truncate the
  // copy) — confirm at the call site.
  const int32 output_size = first_dim_size * slice_size;

  // Nothing to copy (e.g. slice_size == 0). Don't ask for a launch config with
  // zero work: depending on the TF version that either yields a zero-block
  // launch that fails and trips TF_CHECK_OK, or CHECK-fails in the helper.
  if (output_size == 0) {
    return;
  }

  auto config = GetCudaLaunchConfig(output_size, gpu_device);

  TF_CHECK_OK(CudaLaunchKernel(DynamicStitchKernel<T>, config.block_count,
                               config.thread_per_block, 0, gpu_device.stream(),
                               slice_size, output_size, input_indices,
                               input_ptrs, output));
}
65 
// Explicitly instantiates DynamicStitchGPUImpl<T> for each element type listed
// below, so the host-side op code can link against these definitions.
#define REGISTER_GPU(T)                                           \
  template void DynamicStitchGPUImpl(                             \
      const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
      const int32 first_dim_size,                                 \
      const GpuDeviceArrayStruct<int32>& input_indices,           \
      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int32(REGISTER_GPU);

#undef REGISTER_GPU
80 
81 }  // namespace tensorflow
82 #endif  // GOOGLE_CUDA
83