/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

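// Extracts the GEMM configuration (operand shapes, output shape, and the
// GemmBackendConfig proto) from a GEMM custom-call instruction.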
GpuGemmConfig GetGpuGemmConfig(const HloInstruction *gemm) {
  GpuGemmConfig config;
  config.output_shape = gemm->shape();
  config.lhs_shape = gemm->operand(0)->shape();
  config.rhs_shape = gemm->operand(1)->shape();
  auto backend_config_or = gemm->backend_config<GemmBackendConfig>();
  config.backend_config = std::move(backend_config_or.ValueOrDie());
  return config;
}

GemmThunk::GemmThunk(ThunkInfo thunk_info, GpuGemmConfig config,
                     const BufferAllocation::Slice &lhs_buffer,
                     const BufferAllocation::Slice &rhs_buffer,
                     const BufferAllocation::Slice &output_buffer,
                     bool implements_whole_instruction)
    : Thunk(Kind::kGemm, thunk_info),
      config_(std::move(config)),
      lhs_buffer_(lhs_buffer),
      rhs_buffer_(rhs_buffer),
      output_buffer_(output_buffer),
      implements_whole_instruction_(implements_whole_instruction) {}

Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
  auto get_device_address = [&](const BufferAllocation::Slice &slice) {
    return params.buffer_allocations->GetDeviceAddress(slice);
  };

  VLOG(3) << "Running GEMM thunk";
  se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_);
  se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_);
  se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
  return RunGemm(config_, lhs_data, rhs_data, output_data, params.stream,
                 implements_whole_instruction_, profile_index(),
                 params.profiler);
}

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  se::DeviceMemoryBase data;
  bool transpose;  // Whether this matrix needs to be transposed.
  int64 num_rows;
  int64 num_cols;
};

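// Performs a GEMM of lhs_matrix and rhs_matrix on the given stream and writes
// the result to output_matrix. If an algorithm is provided, the call is issued
// with that explicit BLAS algorithm; otherwise the BLAS library chooses one.
// Returns false if the element type is unsupported or the launch fails.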
template <typename Element, typename AlphaType>
static bool DoGemmWithAlgorithm(
    int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
    MatrixDescriptor output_matrix, AlphaType alpha, double beta,
    se::Stream *stream, absl::optional<se::blas::AlgorithmType> algorithm,
    se::blas::ProfileResult *output_profile_result) {
  DCHECK(!output_matrix.transpose);

  PrimitiveType type = primitive_util::NativeToPrimitiveType<Element>();

  // Converts from an XLA PrimitiveType to a blas::ComputationType, which is
  // used to specify the precision with which matmul computations should be
  // performed, separately from the precision of the inputs and result.
  se::blas::ComputationType computation_type;
  switch (type) {
    case F16:
      // Use F32 as the computation type for F16, since we currently only
      // implement the cuDNN pseudo-half configuration for half precision.
      computation_type = se::blas::ComputationType::kF32;
      break;
    case F32:
      computation_type = se::blas::ComputationType::kF32;
      break;
    case F64:
      computation_type = se::blas::ComputationType::kF64;
      break;
    case C64:
      computation_type = se::blas::ComputationType::kComplexF32;
      break;
    case C128:
      computation_type = se::blas::ComputationType::kComplexF64;
      break;
    default:
      return false;
  }

  se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
  se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
  se::DeviceMemory<Element> output_data(output_matrix.data);

  auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  if (algorithm) {
    // Autotuning is disabled for batch_size != 1.
    CHECK_EQ(1, batch_size);
    return stream
        ->ThenBlasGemmWithAlgorithm(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols,
            /*size of reduce dim=*/k,
            /*alpha=*/static_cast<Element>(alpha), lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows,
            /*beta=*/static_cast<Element>(beta), &output_data,
            /*leading dim of output=*/output_matrix.num_rows, computation_type,
            *algorithm, output_profile_result)
        .ok();
  }

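  // Batched GEMM: each batch element is a full num_rows x num_cols matrix and
  // consecutive matrices are laid out contiguously, so the stride between them
  // is num_rows * num_cols elements.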
  if (batch_size != 1) {
    int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
    int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
    int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
    return stream
        ->ThenBlasGemmStridedBatched(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols, /*size of reduce dim=*/k,
            /*alpha=*/alpha, lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
            /*beta=*/beta, &output_data,
            /*leading dim of output=*/output_matrix.num_rows, output_stride,
            batch_size)
        .ok();
  }

  return stream
      ->ThenBlasGemm(
          lhs_transpose, rhs_transpose, output_matrix.num_rows,
          output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
          lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
          /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
          &output_data, /*leading dim of output=*/output_matrix.num_rows)
      .ok();
}

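// Runs the GEMM described by gemm_config on the given stream, reading the
// operands from lhs_buffer and rhs_buffer and writing the product to
// output_buffer. The element type, scaling factors, and (optionally) the BLAS
// algorithm are taken from the backend config unless an explicit algorithm is
// passed in.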
Status RunGemm(const GpuGemmConfig &gemm_config,
               se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream *stream,
               bool implements_whole_instruction,
               absl::optional<int64> profile_index,
               HloExecutionProfiler *profiler,
               se::blas::ProfileResult *profile_result,
               absl::optional<se::blas::AlgorithmType> algorithm) {
  VLOG(2) << "Executing a GemmThunk";

  const Shape &output_shape = gemm_config.output_shape;
  const Shape &lhs_shape = gemm_config.lhs_shape;
  const Shape &rhs_shape = gemm_config.rhs_shape;
  const GemmBackendConfig &backend_config = gemm_config.backend_config;

  const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers();
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
           dim_nums.rhs_batch_dimensions_size());
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank());

  int64 row_dim = dim_nums.lhs_batch_dimensions_size();
  int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;

  int64 batch_size = backend_config.batch_size();

  // Check that the batch dims don't cover the last two dims.
  for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
    CHECK_NE(row_dim, batch_dim);
    CHECK_NE(col_dim, batch_dim);
  }

  // Verify that the non-batch dimensions are minor-most. This is required for
  // efficient access.
  for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) {
    CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
    CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
  }

  int64 output_num_rows = output_shape.dimensions(row_dim);
  int64 output_num_cols = output_shape.dimensions(col_dim);

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert a dot between row-major matrices into one
  // between column-major matrices. The key insight for the conversion is that,
  // in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform a dot between row-major matrices
  // as follows.
  //
  //   row-major(C)
  //     = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //     = gemm(column-major(B^T), column-major(A^T))
  //     = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
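  //
  // As a concrete example, the 2x3 row-major matrix
  //
  //   [1 2 3]
  //   [4 5 6]
  //
  // is stored linearly as 1, 2, 3, 4, 5, 6. Read as column-major data, the
  // same six values describe the 3x2 matrix
  //
  //   [1 4]
  //   [2 5]
  //   [3 6]
  //
  // which is exactly the transpose of the original matrix.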
  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape &shape,
                             bool transpose) -> MatrixDescriptor {
    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
                           LayoutUtil::Minor(output_shape.layout(), row_dim);
    return MatrixDescriptor{
        data, static_cast<bool>(transpose ^ layout_mismatch),
        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
        shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
  };

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_buffer, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim);
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
  std::unique_ptr<ScopedInstructionProfiler> op_profiler =
      profiler ? profiler->MakeScopedInstructionProfiler(
                     implements_whole_instruction ? profile_index : -1)
               : nullptr;

  if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_num_cols, output_num_rows);
  }

  const MatrixDescriptor output_matrix{output_buffer, /*transpose=*/false,
                                       output_num_rows, output_num_cols};
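  // Prefer an explicitly requested algorithm; otherwise fall back to the
  // algorithm recorded in the backend config, if one was selected.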
  auto best_algorithm = [&]() -> absl::optional<se::blas::AlgorithmType> {
    if (algorithm) {
      return *algorithm;
    }
    if (backend_config.algorithm_case() ==
        GemmBackendConfig::ALGORITHM_NOT_SET) {
      return absl::nullopt;
    }
    return backend_config.selected_algorithm();
  }();

  complex128 alpha = {backend_config.alpha_real(), backend_config.alpha_imag()};
  double beta = backend_config.beta();

  bool launch_ok = [&]() {
    switch (output_shape.element_type()) {
      case F16:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<Eigen::half, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case F32:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<float, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case F64:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<double, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case C64:
        return DoGemmWithAlgorithm<complex64, complex64>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix,
            static_cast<complex64>(alpha), beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case C128:
        return DoGemmWithAlgorithm<complex128, complex128>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
            stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      default:
        return false;
    }
  }();

  if (!launch_ok) {
    return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla