/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"

#include <algorithm>
#include <atomic>
#include <ostream>
#include <string>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

std::ostream& operator<<(std::ostream& out,
                         const LaunchDimensions& launch_dims) {
  LaunchDimensions::Dim3D block_counts = launch_dims.block_counts();
  LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block();
  out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]",
                         block_counts.x, block_counts.y, block_counts.z,
                         thread_counts.x, thread_counts.y, thread_counts.z);
  return out;
}

static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) {
  int64 threads_per_block = gpu_device_info.threads_per_block_limit;
  if (threads_per_block <= 0) {
    static std::atomic<int64> log_count{0};
    if (log_count.fetch_add(1) < 8) {
      LOG(WARNING) << "Attempting to calculate launch dimensions for GPU "
                      "without full information about its capabilities.  "
                      "StreamExecutor's PopulateDeviceDescription should be "
                      "updated for this device.";
    }
    threads_per_block = gpu_device_info.threads_per_warp;
    if (threads_per_block == 0) {
      // Fall back to *something* if we can't even get num threads per warp.
      threads_per_block = 32;
    }
  }
  return threads_per_block;
}

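// As a sketch of the fallback above (hypothetical device numbers): a device
// that reports threads_per_block_limit = 0 but threads_per_warp = 32 logs the
// warning (at most 8 times across all calls) and yields a 32-thread limit; a
// device that reports neither yields the hard-coded 32.
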
// Calculates the launch dimensions used to invoke a kernel that operates on
// every element of `shape`.
LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
                                           GpuDeviceInfo gpu_device_info,
                                           int unroll_factor, bool few_waves) {
  int64 num_elements = ShapeUtil::ElementsIn(shape);
  if (num_elements <= 1) {
    return LaunchDimensions();
  }

  CHECK_EQ(num_elements % unroll_factor, 0);
  num_elements = num_elements / unroll_factor;

  // Since we don't do any inter-warp communication, we're free to choose any
  // block size we want, subject to hardware constraints.  We choose the
  // largest block size allowed, as empirically, this is a performance win on
  // almost (but not all) benchmarks.
  //
  // My guess is that using a larger block size encourages ptxas to decrease
  // per-thread register usage, thus allowing for higher occupancy, but I
  // haven't verified this.
  //
  // TODO(jlebar): Investigate this further, and tune this heuristic so we can
  // run faster on the few benchmarks where smaller block size helps.
  int64 threads_per_block = ThreadsPerBlockLimit(gpu_device_info);
  // We unroll kernels to make use of vectorized loads/stores. This means we
  // need more registers to hold intermediate values. Reduce the number of
  // threads per block to increase the number of registers available to each
  // thread. Make sure we still have a multiple of 32.
  threads_per_block =
      RoundUpToNearest(threads_per_block / unroll_factor, int64{32});
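  // As a worked example of the rounding above (hypothetical numbers): with a
  // 1024-thread block limit and unroll_factor = 4, this computes
  // RoundUpToNearest(1024 / 4, 32) = 256 threads per block; with a 96-thread
  // limit and unroll_factor = 4, RoundUpToNearest(24, 32) = 32 keeps the
  // count a multiple of the warp size.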
  if (num_elements < threads_per_block) {
    threads_per_block = num_elements;
    VLOG(2) << "Capping the number of threads per block at the element count ("
            << threads_per_block << "), which is smaller.";
  }

  int64 block_count = CeilOfRatio(num_elements, threads_per_block);
  if (few_waves) {
    int64 capped_threads_per_block = std::min<int64>(threads_per_block, 128);
    int64 capped_block_count =
        gpu_device_info.core_count *
        (gpu_device_info.threads_per_core_limit / capped_threads_per_block);
    // Never increase the number of blocks: for small num_elements the capped
    // count can exceed the original one.
    if (capped_block_count < block_count) {
      threads_per_block = capped_threads_per_block;
      block_count = capped_block_count;
    }
  }
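  // Worked example of the few_waves cap (hypothetical device numbers): with
  // core_count = 80 and threads_per_core_limit = 2048, the cap is
  // 80 * (2048 / 128) = 1280 blocks of 128 threads, i.e. just enough blocks
  // to fill every core once at full occupancy.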
  VLOG(2) << absl::StrFormat(
      "Initialized the block count to ceil(# of elements / threads per "
      "block) = ceil(%d/%d) = %d",
      num_elements, threads_per_block, block_count);

  return LaunchDimensions(block_count, threads_per_block);
}
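
// A minimal usage sketch (hypothetical field values; assumes a GpuDeviceInfo
// normally populated from StreamExecutor's device description):
//
//   GpuDeviceInfo info;
//   info.threads_per_block_limit = 1024;
//   info.threads_per_warp = 32;
//   info.core_count = 80;
//   info.threads_per_core_limit = 2048;
//   Shape shape = ShapeUtil::MakeShape(F32, {1024, 1024});
//   LaunchDimensions dims = CalculateLaunchDimensions(
//       shape, info, /*unroll_factor=*/1, /*few_waves=*/false);
//   VLOG(1) << dims;  // Prints via the operator<< defined above.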

}  // namespace gpu
}  // namespace xla