/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
#define TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <algorithm>

#include "absl/base/casts.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_cuda_alias.h"
// Usage of GetGpuLaunchConfig, GetGpu2DLaunchConfig, and
// GetGpu3DLaunchConfig:
//
// There are two versions of GetGpuLaunchConfig and GetGpu2DLaunchConfig: one
// uses heuristics without any knowledge of the device kernel, the other uses
// cudaOccupancyMaxPotentialBlockSize to determine the launch parameters that
// maximize theoretical occupancy. Currently, only the maximum-occupancy
// version of GetGpu3DLaunchConfig is available.
//
// For a large number of work elements, the convention is that each kernel
// iterates through its assigned range. The return value of GetGpuLaunchConfig
// is the struct GpuLaunchConfig, which contains all the information needed
// for the kernel launch: the virtual number of threads, the number of threads
// per block, and the number of blocks used inside <<< >>> of the launch.
// GetGpu2DLaunchConfig and GetGpu3DLaunchConfig do the same thing as
// GetGpuLaunchConfig; the only difference is the dimension. The macros
// GPU_1D_KERNEL_LOOP and GPU_AXIS_KERNEL_LOOP can be used to write the inner
// loop.
//
/* Sample code:

__global__ void MyKernel1D(GpuLaunchConfig config, other_args...) {
  GPU_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
    do_your_job_here;
  }
}

__global__ void MyKernel2D(Gpu2DLaunchConfig config, other_args...) {
  GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      do_your_job_here;
    }
  }
}

__global__ void MyKernel3D(Gpu3DLaunchConfig config, other_args...) {
  GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      GPU_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
        do_your_job_here;
      }
    }
  }
}

void MyDriverFunc(const Eigen::GpuDevice &d) {
  // use heuristics
  GpuLaunchConfig cfg1 = GetGpuLaunchConfig(10240, d);
  MyKernel1D <<<cfg1.block_count,
                cfg1.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
  Gpu2DLaunchConfig cfg2 = GetGpu2DLaunchConfig(10240, 10240, d);
  MyKernel2D <<<cfg2.block_count,
                cfg2.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
  // (There is no heuristic version of GetGpu3DLaunchConfig; see below.)

  // maximize occupancy
  GpuLaunchConfig cfg3 = GetGpuLaunchConfig(10240, d, MyKernel1D, 0, 0);
  MyKernel1D <<<cfg3.block_count,
                cfg3.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
  Gpu2DLaunchConfig cfg4 = GetGpu2DLaunchConfig(10240, 10240, d,
                                                MyKernel2D, 0, 0);
  MyKernel2D <<<cfg4.block_count,
                cfg4.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
  Gpu3DLaunchConfig cfg5 = GetGpu3DLaunchConfig(4096, 4096, 100, d,
                                                MyKernel3D, 0, 0);
  MyKernel3D <<<cfg5.block_count,
                cfg5.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
}

// See the tests for more examples:
//
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/gpu_kernel_helper_test.cu.cc

*/

namespace tensorflow {

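// Rounds up the integer division a / b; e.g. DivUp(10, 4) == 3, so a
// partially filled final block is still counted.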
inline int DivUp(int a, int b) { return (a + b - 1) / b; }

struct GpuLaunchConfig {
  // Logical number of threads that work on the elements. If each logical
  // thread works on exactly one element, this is the same as the working
  // element count.
  int virtual_thread_count = -1;
  // Number of threads per block.
  int thread_per_block = -1;
  // Number of blocks for the GPU kernel launch.
  int block_count = -1;
};
CREATE_CUDA_TYPE_ALIAS(GpuLaunchConfig, CudaLaunchConfig);

// Calculate the GPU launch config we should use for a kernel launch. This
// assumes the kernel is quite simple and will largely be memory-limited.
// REQUIRES: work_element_count > 0.
inline GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
                                          const Eigen::GpuDevice& d) {
  CHECK_GT(work_element_count, 0);
  GpuLaunchConfig config;
  const int virtual_thread_count = work_element_count;
  const int physical_thread_count = std::min(
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
      virtual_thread_count);
  const int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
  const int block_count =
      std::min(DivUp(physical_thread_count, thread_per_block),
               d.getNumGpuMultiProcessors());

  config.virtual_thread_count = virtual_thread_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}
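// A worked example of the heuristic above, assuming a hypothetical device
// with 80 multiprocessors, 2048 max threads per multiprocessor, and 1024 max
// threads per block. For work_element_count = 10240:
//   physical_thread_count = min(80 * 2048, 10240)        = 10240
//   thread_per_block      = min(1024, 1024)              = 1024
//   block_count           = min(DivUp(10240, 1024), 80)  = 10
// so the launch is <<<10, 1024>>> and each thread handles one element; for
// larger element counts, GPU_1D_KERNEL_LOOP makes each thread handle several.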
#ifndef TENSORFLOW_USE_ROCM
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const Eigen::GpuDevice& d) {
  return GetGpuLaunchConfig(work_element_count, d);
}
#endif

// Calculate the GPU launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize
// occupancy.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
                                   const Eigen::GpuDevice& d, DeviceFunc func,
                                   size_t dynamic_shared_memory_size,
                                   int block_size_limit) {
  CHECK_GT(work_element_count, 0);
  GpuLaunchConfig config;
  int block_count = 0;
  int thread_per_block = 0;

#if GOOGLE_CUDA
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
  // Earlier versions of this HIP routine incorrectly returned void.
  // TODO: re-enable hipError_t error checking once HIP is fixed.
  // The ROCm interface uses unsigned int; convert after checking that the
  // limit is non-negative.
  uint32_t block_count_uint = 0;
  uint32_t thread_per_block_uint = 0;
  CHECK_GE(block_size_limit, 0);
  uint32_t block_size_limit_uint = static_cast<uint32_t>(block_size_limit);
  hipOccupancyMaxPotentialBlockSize(&block_count_uint, &thread_per_block_uint,
                                    func, dynamic_shared_memory_size,
                                    block_size_limit_uint);
  block_count = static_cast<int>(block_count_uint);
  thread_per_block = static_cast<int>(thread_per_block_uint);
#endif

  // Never launch more blocks than are needed to cover all work elements.
  block_count =
      std::min(block_count, DivUp(work_element_count, thread_per_block));

  config.virtual_thread_count = work_element_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfig, GetCudaLaunchConfig);
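// A hedged numeric sketch of the occupancy variant above: if, hypothetically,
// cudaOccupancyMaxPotentialBlockSize returned block_count = 160 and
// thread_per_block = 512 for some kernel, then for work_element_count = 10240
// the final clamp reduces block_count to DivUp(10240, 512) = 20, since
// launching more than 20 blocks of 512 threads would leave whole blocks idle.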

// Calculate the GPU launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize
// occupancy.
// The returned launch config has thread_per_block set to fixed_block_size.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize(
    int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func,
    size_t dynamic_shared_memory_size, int fixed_block_size) {
  CHECK_GT(work_element_count, 0);
  GpuLaunchConfig config;
  int block_count = 0;

#if GOOGLE_CUDA
  cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &block_count, func, fixed_block_size, dynamic_shared_memory_size);
  CHECK_EQ(err, cudaSuccess);
  block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
                         DivUp(work_element_count, fixed_block_size));
#elif TENSORFLOW_USE_ROCM
  // ROCm TODO: re-enable this once hipOccupancyMaxActiveBlocksPerMultiprocessor
  // is implemented:
  // hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
  //    &block_count, &thread_per_block, func, dynamic_shared_memory_size,
  //    block_size_limit);
  // CHECK_EQ(err, hipSuccess);

  // Apply the heuristic in GetGpuLaunchConfig(int, const Eigen::GpuDevice&)
  // that the kernel is quite simple and will largely be memory-limited.
  const int physical_thread_count = std::min(
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
      work_element_count);
  // Assume the kernel is simple enough that it is okay to use 1024 threads
  // per workgroup.
  int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
  block_count = std::min(DivUp(physical_thread_count, thread_per_block),
                         d.getNumGpuMultiProcessors());
#endif

  config.virtual_thread_count = work_element_count;
  config.thread_per_block = fixed_block_size;
  config.block_count = block_count;
  return config;
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfigFixedBlockSize,
                                GetCudaLaunchConfigFixedBlockSize);
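// A hedged numeric sketch of the fixed-block-size variant: assume
// cudaOccupancyMaxActiveBlocksPerMultiprocessor reports 4 resident blocks per
// multiprocessor for a hypothetical kernel on an 80-multiprocessor device.
// With fixed_block_size = 256 and work_element_count = 10240,
//   block_count = min(4 * 80, DivUp(10240, 256)) = min(320, 40) = 40,
// i.e. the element count, not occupancy, is the binding constraint here.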

struct Gpu2DLaunchConfig {
  dim3 virtual_thread_count = dim3(0, 0, 0);
  dim3 thread_per_block = dim3(0, 0, 0);
  dim3 block_count = dim3(0, 0, 0);
};
CREATE_CUDA_TYPE_ALIAS(Gpu2DLaunchConfig, Cuda2DLaunchConfig);

inline Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim,
                                              const Eigen::GpuDevice& d) {
  Gpu2DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0) {
    return config;
  }

  const int kThreadsPerBlock = 256;
  int block_cols = std::min(xdim, kThreadsPerBlock);
  // ok to round down here and just do more loops in the kernel
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

  const int physical_thread_count =
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();

  const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);

  config.virtual_thread_count = dim3(xdim, ydim, 1);
  config.thread_per_block = dim3(block_cols, block_rows, 1);

  int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);

  config.block_count = dim3(
      grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
  return config;
}
#ifndef TENSORFLOW_USE_ROCM
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                                const Eigen::GpuDevice& d) {
  return GetGpu2DLaunchConfig(xdim, ydim, d);
}
#endif
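// A worked example of the 2D heuristic above, again assuming a hypothetical
// device with 80 multiprocessors and 2048 threads each. For
// xdim = ydim = 10240:
//   block_cols = min(10240, 256) = 256, block_rows = max(256 / 256, 1) = 1
//   max_blocks = max((80 * 2048) / 256, 1) = 640
//   grid_x     = min(DivUp(10240, 256), 640) = 40
//   grid_y     = min(640 / 40, max(10240 / 1, 1)) = 16
// giving <<<dim3(40, 16), dim3(256, 1)>>>; GPU_AXIS_KERNEL_LOOP then walks
// each thread over the remaining rows.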

// Calculate the GPU 2D and 3D launch config we should use for a kernel
// launch. This variant takes the resource limits of func into account to
// maximize occupancy.
using Gpu3DLaunchConfig = Gpu2DLaunchConfig;
CREATE_CUDA_TYPE_ALIAS(Gpu3DLaunchConfig, Cuda3DLaunchConfig);

template <typename DeviceFunc>
Gpu3DLaunchConfig GetGpu3DLaunchConfig(int xdim, int ydim, int zdim,
                                       const Eigen::GpuDevice& d,
                                       DeviceFunc func,
                                       size_t dynamic_shared_memory_size,
                                       int block_size_limit) {
  Gpu3DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
    return config;
  }

  int dev;
#if GOOGLE_CUDA
  cudaGetDevice(&dev);
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, dev);
#elif TENSORFLOW_USE_ROCM
  hipGetDevice(&dev);
  hipDeviceProp_t deviceProp;
  hipGetDeviceProperties(&deviceProp, dev);
#endif
  int xthreadlimit = deviceProp.maxThreadsDim[0];
  int ythreadlimit = deviceProp.maxThreadsDim[1];
  int zthreadlimit = deviceProp.maxThreadsDim[2];
  int xgridlimit = deviceProp.maxGridSize[0];
  int ygridlimit = deviceProp.maxGridSize[1];
  int zgridlimit = deviceProp.maxGridSize[2];

  int block_count = 0;
  int thread_per_block = 0;

#if GOOGLE_CUDA
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
  // ROCm TODO: re-enable this once hipOccupancyMaxPotentialBlockSize is
  // implemented:
  // hipError_t err = hipOccupancyMaxPotentialBlockSize(
  //    &block_count, &thread_per_block, func, dynamic_shared_memory_size,
  //    block_size_limit);
  // CHECK_EQ(err, hipSuccess);

  const int physical_thread_count =
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
  thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
  block_count = std::min(DivUp(physical_thread_count, thread_per_block),
                         d.getNumGpuMultiProcessors());
#endif

  // Distribute the block's threads across x first, then y, then z, without
  // exceeding the per-dimension hardware limits.
  int threadsx = std::min({xdim, thread_per_block, xthreadlimit});
  int threadsy =
      std::min({ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit});
  int threadsz =
      std::min({zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
                zthreadlimit});

  // Likewise distribute the blocks, never exceeding the occupancy-derived
  // total block count or the grid-size limits.
  int blocksx = std::min({block_count, DivUp(xdim, threadsx), xgridlimit});
  int blocksy = std::min(
      {DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit});
  int blocksz = std::min({DivUp(block_count, (blocksx * blocksy)),
                          DivUp(zdim, threadsz), zgridlimit});

  config.virtual_thread_count = dim3(xdim, ydim, zdim);
  config.thread_per_block = dim3(threadsx, threadsy, threadsz);
  config.block_count = dim3(blocksx, blocksy, blocksz);
  return config;
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu3DLaunchConfig, GetCuda3DLaunchConfig);
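// A hedged worked example of the 3D distribution above, assuming the
// occupancy query returns block_count = 160 and thread_per_block = 1024 on a
// device with per-block thread limits (1024, 1024, 64). For
// (xdim, ydim, zdim) = (4096, 4096, 100):
//   threads = (min(4096, 1024), 1, 1) = (1024, 1, 1)
//   blocksx = min(160, DivUp(4096, 1024))          = 4
//   blocksy = min(DivUp(160, 4), DivUp(4096, 1))   = 40
//   blocksz = min(DivUp(160, 160), DivUp(100, 1))  = 1
// so the kernel loops over the leftover y rows and all z slices.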

template <typename DeviceFunc>
Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim,
                                       const Eigen::GpuDevice& d,
                                       DeviceFunc func,
                                       size_t dynamic_shared_memory_size,
                                       int block_size_limit) {
  return GetGpu3DLaunchConfig(xdim, ydim, 1, d, func,
                              dynamic_shared_memory_size, block_size_limit);
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu2DLaunchConfig, GetCuda2DLaunchConfig);

#if GOOGLE_CUDA
template <typename DeviceFunc>
Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                         const Eigen::GpuDevice& d,
                                         DeviceFunc func,
                                         size_t dynamic_shared_memory_size,
                                         int block_size_limit) {
  return GetGpu2DLaunchConfig(xdim, ydim, d, func, dynamic_shared_memory_size,
                              block_size_limit);
}
#endif  // GOOGLE_CUDA

namespace detail {
template <typename... Ts, size_t... Is>
std::array<void*, sizeof...(Ts)> GetArrayOfElementPointersImpl(
    std::tuple<Ts...>* tuple, absl::index_sequence<Is...>) {
  return {{&std::get<Is>(*tuple)...}};
}
// Returns an array of void pointers to the elements of the given tuple.
template <typename... Ts>
std::array<void*, sizeof...(Ts)> GetArrayOfElementPointers(
    std::tuple<Ts...>* tuple) {
  return GetArrayOfElementPointersImpl(tuple,
                                       absl::index_sequence_for<Ts...>{});
}
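// A hedged usage sketch (not part of this header's API): driver-style launch
// entry points such as cudaLaunchKernel take a void** argument array, which
// is exactly what GetArrayOfElementPointers produces. The kernel, tuple
// contents, and launch parameters below are hypothetical; the tuple must
// outlive the launch call.
//
//   auto args = std::make_tuple(config, data_ptr);
//   std::array<void*, 2> arg_ptrs = detail::GetArrayOfElementPointers(&args);
//   cudaLaunchKernel(reinterpret_cast<const void*>(&MyKernel1D),
//                    config.block_count, config.thread_per_block,
//                    arg_ptrs.data(), 0, d.stream());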

template <bool...>
struct BoolPack;
template <bool... Bs>
using NoneTrue = std::is_same<BoolPack<Bs..., false>, BoolPack<false, Bs...>>;
// Returns whether none of the types in Ts is a reference.
template <typename... Ts>
constexpr bool NoneIsReference() {
  return NoneTrue<(std::is_reference<Ts>::value)...>::value;
}
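// NoneTrue works because appending `false` at the back of the pack yields the
// same type as prepending it at the front only when every element of the pack
// is false. A brief illustration:
//
//   static_assert(detail::NoneIsReference<int, float*>(), "");
//   static_assert(!detail::NoneIsReference<int, float&>(), "");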
}  // namespace detail
}  // namespace tensorflow
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_