• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"

#include <algorithm>
#include <atomic>
#include <ostream>
#include <string>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
24 
25 namespace xla {
26 namespace gpu {
27 
operator <<(std::ostream & out,const LaunchDimensions & launch_dims)28 std::ostream& operator<<(std::ostream& out,
29                          const LaunchDimensions& launch_dims) {
30   LaunchDimensions::Dim3D block_counts = launch_dims.block_counts();
31   LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block();
32   out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]",
33                          block_counts.x, block_counts.y, block_counts.z,
34                          thread_counts.x, thread_counts.y, thread_counts.z);
35   return out;
36 }
37 
ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info)38 static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) {
39   int64_t threads_per_block = gpu_device_info.threads_per_block_limit;
40   if (threads_per_block <= 0) {
41     static std::atomic<int64> log_count{0};
42     if (log_count.fetch_add(1) < 8) {
43       LOG(WARNING) << "Attempting to calculate launch dimensions for GPU "
44                       "without full information about its capabilities.  "
45                       "StreamExecutor's PopulateDeviceDescription should be "
46                       "updated for this device.";
47     }
48     threads_per_block = gpu_device_info.threads_per_warp;
49     if (threads_per_block == 0) {
50       // Fall back to *something* if we can't even get num threads per warp.
51       threads_per_block = 32;
52     }
53   }
54   return threads_per_block;
55 }
56 
ThreadsPerBlockRowVectorized(const Shape & shape,GpuDeviceInfo gpu_device_info,LaunchDimensionsConfig dim_config)57 int64 ThreadsPerBlockRowVectorized(const Shape& shape,
58                                    GpuDeviceInfo gpu_device_info,
59                                    LaunchDimensionsConfig dim_config) {
60   if (shape.dimensions().empty()) {
61     return -1;
62   }
63   int64_t threads_per_block_row_vectorized =
64       shape.dimensions().back() / dim_config.unroll_factor;
65   if (dim_config.row_vectorized &&
66       shape.dimensions().back() % dim_config.unroll_factor == 0 &&
67       // If the row size is a multiple of 256, then use the old code
68       // path that use a block size of 256. This give small speed up on V100.
69       // Vectorization of the row load was already happening.
70       (shape.dimensions().back() % 256) != 0 &&
71       // Do not trigger the row vectorized codepath if this create too
72       // small block size as this hurt performance.
73       (threads_per_block_row_vectorized >= 128 &&
74        threads_per_block_row_vectorized <=
75            gpu_device_info.threads_per_block_limit)) {
76     return threads_per_block_row_vectorized;
77   }
78   return -1;
79 }
80 
CalculateLaunchDimensions(const Shape & shape,GpuDeviceInfo gpu_device_info,LaunchDimensionsConfig dim_config)81 StatusOr<LaunchDimensions> CalculateLaunchDimensions(
82     const Shape& shape, GpuDeviceInfo gpu_device_info,
83     LaunchDimensionsConfig dim_config) {
84   int64_t num_elements = ShapeUtil::ElementsIn(shape);
85   if (num_elements <= 1) {
86     return LaunchDimensions();
87   }
88 
89   CHECK_EQ(num_elements % dim_config.unroll_factor, 0);
90   num_elements = num_elements / dim_config.unroll_factor;
91 
92   // Since we don't do any inter-warp communication, we're free to choose any
93   // block size we want, subject to hardware constraints.  We choose the largest
94   // block size allowed, as empirically, this is a performance win on almost
95   // (but not all) benchmarks.
96   //
97   // My guess is that using a larger block size encourages ptxas to decrease
98   // per-thread register usage, thus allowing for higher occupancy, but I
99   // haven't verified this.
100   //
101   // TODO(jlebar): Investigate this further, and tune this heuristic so we can
102   // run faster on the few benchmarks where smaller block size helps.
103   int64_t threads_per_block = ThreadsPerBlockLimit(gpu_device_info);
104   int64_t threads_per_block_row_vectorized =
105       ThreadsPerBlockRowVectorized(shape, gpu_device_info, dim_config);
106   if (threads_per_block_row_vectorized > 0) {
107     threads_per_block = threads_per_block_row_vectorized;
108     VLOG(2) << "Update # of threads per block to (" << threads_per_block
109             << ") to be row_vectorized.";
110   } else {
111     CHECK(!dim_config.row_vectorized);
112     // We unroll kernels to make use of vectorized loads/stores. This means we
113     // need more registers to hold intermediate values. Reduce the number of
114     // threads per block to increase the number of registers available to ptxas.
115     // Make sure we still have a multiple of 32.
116     threads_per_block = RoundUpToNearest(
117         threads_per_block / dim_config.unroll_factor, int64{32});
118     if (num_elements < threads_per_block) {
119       threads_per_block = num_elements;
120       VLOG(2) << "Update # of threads per block to the element count ("
121               << threads_per_block << ") because the latter is smaller.";
122     }
123   }
124 
125   int64_t block_count = CeilOfRatio(num_elements, threads_per_block);
126   if (dim_config.few_waves && !dim_config.row_vectorized) {
127     int64_t capped_threads_per_block = std::min<int64>(threads_per_block, 128);
128     int64_t capped_block_count =
129         gpu_device_info.core_count *
130         (gpu_device_info.threads_per_core_limit / capped_threads_per_block);
131     if (capped_block_count < block_count) {
132       threads_per_block = capped_threads_per_block;
133       block_count = capped_block_count;
134     }
135   } else if (dim_config.few_waves && dim_config.row_vectorized) {
136     int64_t capped_threads_per_block = std::min<int64>(threads_per_block, 128);
137     if (dim_config.row_vectorized) {
138       // Keep the threads_per_block found for row_vectorized.
139       capped_threads_per_block = threads_per_block;
140     }
141     int64_t min_block_count =
142         gpu_device_info.core_count *
143         (gpu_device_info.threads_per_core_limit / capped_threads_per_block);
144     int64_t capped_block_count = block_count;
145     // This multiple of 32 was tuned to not cause regression on multiple
146     // benchmarks.  It isn't a value that is optimal for all
147     // kernels. Maybe looking at the arithmetic intensity of the
148     // kernels can specialize the multiple per kernel.
149     while (capped_block_count > (32 * min_block_count)) {
150       capped_block_count /= 2;
151     }
152     // Do not increase the number of blocks. This can happens for
153     // small num_elements.
154     if (capped_block_count < block_count) {
155       threads_per_block = capped_threads_per_block;
156       block_count = capped_block_count;
157     }
158   }
159   if (gpu_device_info.block_dim_limit_x > 0 &&
160       block_count >= gpu_device_info.block_dim_limit_x) {
161     return tensorflow::errors::Unimplemented(
162         "Kernel launch needs more blocks (", block_count,
163         ") than allowed by hardware (", gpu_device_info.block_dim_limit_x,
164         ").");
165   }
166 
167   VLOG(2) << absl::StrFormat(
168       "Initialized the block count to ceil(# of elements / threads per "
169       "block) = ceil(%d/%d) = %d",
170       num_elements, threads_per_block, block_count);
171 
172   return LaunchDimensions(block_count, threads_per_block);
173 }
174 
175 }  // namespace gpu
176 }  // namespace xla
177