/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/concat_lib_gpu.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

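// Concatenates inputs that all have the same width (`split_size` columns)
// along the second dimension into a `total_rows` x `total_cols` output.
// Both loops below are grid-strided, so a single launch covers the output.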
template <typename T, typename IntType>
__global__ void concat_fixed_kernel(
    GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
    int total_rows, int total_cols, T* output) {
  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;

  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;

    IntType split = gidx / split_size;
    const T* input_ptr = input_ptrs[split];
    IntType col_offset = gidx % split_size;
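    // e.g. with split_size = 4, output column gidx = 10 reads from
    // input_ptrs[2] (10 / 4) at column 2 (10 % 4).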
#pragma unroll
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y) {
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * split_size + col_offset];
    }
  }
}

}  // end namespace

// cannot be in anonymous namespace due to extern shared memory
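// Concatenates inputs of varying widths along the second dimension.
// `output_scan` holds the starting output column of each input (a prefix sum
// of the input widths), with a final entry marking the end of the last input;
// each thread locates its segment in that scan before copying a column.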
template <typename T, typename IntType, bool useSmem>
__global__ void concat_variable_kernel(
    GpuDeviceArrayStruct<const T*> input_ptr_data,
    GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
    IntType total_cols, T* output) {
  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
  IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);

  // do upper_bound on col to find which pointer we should be using
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  IntType num_inputs = input_ptr_data.size;

  // verbose declaration needed due to template
  extern __shared__ __align__(sizeof(T)) unsigned char smem[];
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

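  // When useSmem is set, the thread block cooperatively stages the column
  // scan into shared memory so the per-thread searches below read it from
  // shared rather than global memory.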
  if (useSmem) {
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }

  // do an initial binary search and then scan linearly from there;
  // this works well both when there are many small segments and when the
  // segments are long
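  // e.g. with col_scan = {0, 3, 7, 12} (three inputs of widths 3, 4, and 5)
  // and gidx = 5, upper_bound returns 2, so segment = 1: column 5 belongs to
  // input 1, which covers output columns [3, 7), at local column 5 - 3 = 2.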
  IntType segment =
      cuda_helper::upper_bound<IntType>(col_scan, num_inputs, gidx) - 1;

  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    const T* input_ptr = input_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * segment_width + local_col];
  }
}

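// Eigen-based path: copies each input into its column slice of `*output`,
// one device slice assignment per input.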
template <typename T, typename IntType>
void ConcatGPUSlice(
    const Eigen::GpuDevice& gpu_device,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs_flat,
    typename TTypes<T, 2>::Matrix* output) {
  Eigen::array<IntType, 2> offset{0, 0};
  for (int i = 0; i < inputs_flat.size(); ++i) {
    Eigen::array<IntType, 2> size;
    size[0] = inputs_flat[i]->dimension(0);
    size[1] = inputs_flat[i]->dimension(1);
    if (std::is_same<IntType, int32>::value) {
      To32Bit(*output).slice(offset, size).device(gpu_device) =
          To32Bit(*inputs_flat[i]);
    } else {
      output->slice(offset, size).device(gpu_device) = *inputs_flat[i];
    }

    offset[1] += size[1];
  }
}

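// Launches one of the kernels above: the fixed-size kernel when every input
// has the same width (`fixed_size`), otherwise the variable-size kernel,
// caching the column scan in shared memory when it is small enough to pay off.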
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
                   const GpuDeviceArrayStruct<const T*>& input_ptrs,
                   const GpuDeviceArrayStruct<IntType>& output_scan,
                   bool fixed_size, int split_size,
                   typename TTypes<T, 2>::Matrix* output) {
  auto config = GetCuda2DLaunchConfig(output->dimension(1),
                                      output->dimension(0), gpu_device);

  if (fixed_size) {
    concat_fixed_kernel<T, IntType>
        <<<config.block_count, config.thread_per_block, 0,
           gpu_device.stream()>>>(input_ptrs, split_size, output->dimension(0),
                                  output->dimension(1), output->data());
  } else {
    IntType smem_max = gpu_device.sharedMemPerBlock();
    IntType smem_usage = output_scan.size * sizeof(IntType);
    // The performance crossover sits below the maximum available shared memory
    // on most processors, possibly due to decreasing occupancy. 4096 inputs is
    // a lot, so most callers will take the shared-memory path.
    const int32 kMaxSmemBytesPerformance = 16384;
    if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
      concat_variable_kernel<T, IntType, true>
          <<<config.block_count, config.thread_per_block, smem_usage,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
    else
      concat_variable_kernel<T, IntType, false>
          <<<config.block_count, config.thread_per_block, 0,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
  }
}

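// Explicit instantiations for every supported element type, with both 32-bit
// and 64-bit index variants.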
#define REGISTER_GPUCONCAT32(T)                                               \
  template void ConcatGPUSlice<T, int32>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPUCONCAT64(T)                                               \
  template void ConcatGPUSlice<T, int64>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU32(T)                                              \
  template void ConcatGPUImpl<T, int32>(                               \
      const Eigen::GpuDevice& d,                                       \
      const GpuDeviceArrayStruct<const T*>& input_ptrs,                \
      const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU64(T)                                              \
  template void ConcatGPUImpl<T, int64>(                               \
      const Eigen::GpuDevice& d,                                       \
      const GpuDeviceArrayStruct<const T*>& input_ptrs,                \
      const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_complex64(REGISTER_GPUCONCAT32);
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int32(REGISTER_GPUCONCAT32);  // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPUCONCAT32);
TF_CALL_int16(REGISTER_GPUCONCAT32);
TF_CALL_uint8(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);
REGISTER_GPUCONCAT32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_complex64(REGISTER_GPUCONCAT64);
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int32(REGISTER_GPUCONCAT64);  // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPUCONCAT64);
TF_CALL_int16(REGISTER_GPUCONCAT64);
TF_CALL_uint8(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);
REGISTER_GPUCONCAT64(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
TF_CALL_complex64(REGISTER_GPU32);
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int32(REGISTER_GPU32);  // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPU32);
TF_CALL_int16(REGISTER_GPU32);
TF_CALL_uint8(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);
REGISTER_GPU32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
TF_CALL_complex64(REGISTER_GPU64);
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int32(REGISTER_GPU64);  // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPU64);
TF_CALL_int16(REGISTER_GPU64);
TF_CALL_uint8(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);
REGISTER_GPU64(bool);

#undef REGISTER_GPUCONCAT32
#undef REGISTER_GPUCONCAT64
#undef REGISTER_GPU32
#undef REGISTER_GPU64

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA