/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {
GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer,
                     const BufferAllocation::Slice &rhs_buffer,
                     const BufferAllocation::Slice &output_buffer,
                     bool implements_whole_instruction,
                     const HloInstruction *hlo_instruction,
                     const GemmBackendConfig &backend_config)
    : Thunk(Kind::kGemm, hlo_instruction),
      lhs_buffer_(lhs_buffer),
      rhs_buffer_(rhs_buffer),
      output_buffer_(output_buffer),
      implements_whole_instruction_(implements_whole_instruction),
      backend_config_(backend_config) {}

Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
  auto get_device_address = [&](const BufferAllocation::Slice &slice) {
    return params.buffer_allocations->GetDeviceAddress(slice);
  };

  VLOG(3) << "Running GEMM thunk on instruction: " << hlo_instruction();
  se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_);
  se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_);
  se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
  return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data,
                 output_data, params.stream, implements_whole_instruction_,
                 params.profiler);
}

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
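//
// For illustration only (not tied to any particular call site below): a
// column-major 128x64 matrix that BLAS should read untransposed would be
// described as
//
//   MatrixDescriptor{data, /*transpose=*/false, /*num_rows=*/128,
//                    /*num_cols=*/64};
//
// num_rows and num_cols describe the buffer as column-major data; the
// transpose flag tells the BLAS call how to interpret it.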
struct MatrixDescriptor {
  se::DeviceMemoryBase data;
  bool transpose;  // Whether this matrix needs to be transposed.
  int64 num_rows;
  int64 num_cols;
};

template <typename Element, typename AlphaType>
static bool DoGemmWithAlgorithm(
    int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
    MatrixDescriptor output_matrix, AlphaType alpha, double beta,
    se::Stream *stream, absl::optional<se::blas::AlgorithmType> algorithm,
    se::blas::ProfileResult *output_profile_result) {
  DCHECK(!output_matrix.transpose);

  PrimitiveType type = primitive_util::NativeToPrimitiveType<Element>();

  // Converts from an XLA PrimitiveType to a blas::ComputationType, which is
  // used to specify the precision with which matmul computations should be
  // performed, separately from the precision of the inputs and result.
  se::blas::ComputationType computation_type = [&](PrimitiveType type) {
    switch (type) {
      case F16:
        // Use F32 as computation type for F16 as we currently only implement
        // the cuDNN pseudo half configuration for half precision.
        return se::blas::ComputationType::kF32;
      case F32:
        return se::blas::ComputationType::kF32;
      case F64:
        return se::blas::ComputationType::kF64;
      case C64:
        return se::blas::ComputationType::kComplexF32;
      case C128:
        return se::blas::ComputationType::kComplexF64;
      default:
        LOG(FATAL) << "Unsupported type.";
    }
  }(type);

  se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
  se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
  se::DeviceMemory<Element> output_data(output_matrix.data);

  auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
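  // The contracted (reduce) dimension k is the number of columns of op(LHS):
  // lhs_matrix.num_rows when the LHS is transposed, lhs_matrix.num_cols
  // otherwise.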
  auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  if (algorithm) {
    // Autotuning is disabled for batch_size != 1.
    CHECK_EQ(1, batch_size);
    return stream
        ->ThenBlasGemmWithAlgorithm(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols,
            /*size of reduce dim=*/k,
            /*alpha=*/static_cast<Element>(alpha), lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows,
            /*beta=*/static_cast<Element>(beta), &output_data,
            /*leading dim of output=*/output_matrix.num_rows, computation_type,
            *algorithm, output_profile_result)
        .ok();
  }

  if (batch_size != 1) {
    int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
    int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
    int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
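    // Each matrix in the batch is stored contiguously, so the stride between
    // consecutive matrices equals the element count of a single matrix.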
    return stream
        ->ThenBlasGemmStridedBatched(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols, /*size of reduce dim=*/k,
            /*alpha=*/alpha, lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
            /*beta=*/beta, &output_data,
            /*leading dim of output=*/output_matrix.num_rows, output_stride,
            batch_size)
        .ok();
  }

  return stream
      ->ThenBlasGemm(
          lhs_transpose, rhs_transpose, output_matrix.num_rows,
          output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
          lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
          /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
          &output_data, /*leading dim of output=*/output_matrix.num_rows)
      .ok();
}

Status RunGemm(const HloInstruction *gemm,
               const GemmBackendConfig &backend_config,
               se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream *stream,
               bool implements_whole_instruction,
               HloExecutionProfiler *profiler,
               se::blas::ProfileResult *profile_result,
               absl::optional<se::blas::AlgorithmType> algorithm) {
  VLOG(2) << "Executing a GemmThunk";
  CHECK(IsCublasGemm(*gemm));

  const Shape &output_shape = gemm->shape();
  const HloInstruction *lhs = gemm->operand(0);
  const HloInstruction *rhs = gemm->operand(1);

  const Shape &lhs_shape = lhs->shape();
  const Shape &rhs_shape = rhs->shape();

  const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers();
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
           dim_nums.rhs_batch_dimensions_size());
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank());

  int64 row_dim = dim_nums.lhs_batch_dimensions_size();
  int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
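  // For example, with a single batch dimension the output shape is
  // [batch, rows, cols], so row_dim = 1 and col_dim = 2; with no batch
  // dimensions, row_dim = 0 and col_dim = 1.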

  int64 batch_size = backend_config.batch_size();

  // Check that the batch dims don't cover the last two dims.
  for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
    CHECK_NE(row_dim, batch_dim);
    CHECK_NE(col_dim, batch_dim);
  }

  // Verify that the non-batch dimensions are minor-most. This is required for
  // efficient access.
  for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) {
    CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
    CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
  }

  // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
  // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
  // their layout. Therefore, we should treat dimension 0 as row and dimension
  // 1 as column when mapping a matrix Dot to BLAS gemm.
  int64 output_num_rows = output_shape.dimensions(row_dim);
  int64 output_num_cols = output_shape.dimensions(col_dim);

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert dot between row-major matrices to dot
  // between column-major matrices. The key insight for the conversion is that,
  // in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform dot between row-major matrices as
  // follows.
  //
  //   row-major(C)
  //     = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //     = gemm(column-major(B^T), column-major(A^T))
  //     = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
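  //
  // As a concrete, purely illustrative case: to compute row-major C[2,4] =
  // row-major A[2,3] x row-major B[3,4], we call gemm with the operands
  // swapped and described in column-major terms:
  //
  //   gemm(B^T, A^T):  m = 4, n = 2, k = 3,
  //                    leading dim of B^T = 4, of A^T = 3, of C^T = 4,
  //
  // which writes C^T in column-major order, i.e., C in row-major order, into
  // the output buffer.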
  // Builds a MatrixDescriptor for one operand. `transpose` is true when the
  // dot contracts this operand along the non-default dimension (the LHS's rows
  // or the RHS's columns), so BLAS must transpose it. Operands whose layout
  // disagrees with the output's layout are handled by reading them as their
  // own transpose (see the identity above), hence the XOR; row-major operands
  // get their row and column counts swapped so that the descriptor always
  // describes column-major data.
  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape &shape,
                             bool transpose) -> MatrixDescriptor {
    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
                           LayoutUtil::Minor(output_shape.layout(), row_dim);
    return MatrixDescriptor{
        data, static_cast<bool>(transpose ^ layout_mismatch),
        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
        shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
  };

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_buffer, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim);
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
  std::unique_ptr<ScopedInstructionProfiler> op_profiler =
      profiler ? profiler->MakeScopedInstructionProfiler(
                     implements_whole_instruction ? gemm : nullptr)
               : nullptr;

  if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
    // The output is row-major, so compute column-major(C^T) = B^T x A^T
    // instead, per the identity above: swap the operands and the output
    // dimensions.
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_num_cols, output_num_rows);
  }

  const MatrixDescriptor output_matrix{output_buffer, /*transpose=*/false,
                                       output_num_rows, output_num_cols};
  auto best_algorithm = [&]() -> absl::optional<se::blas::AlgorithmType> {
    if (algorithm) {
      return *algorithm;
    }
    if (backend_config.algorithm_case() ==
        GemmBackendConfig::ALGORITHM_NOT_SET) {
      return absl::nullopt;
    }
    return backend_config.selected_algorithm();
  }();

  complex128 alpha = {backend_config.alpha_real(), backend_config.alpha_imag()};
  double beta = backend_config.beta();

  bool launch_ok = [&]() {
    switch (output_shape.element_type()) {
      case F16:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<Eigen::half, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case F32:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<float, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case F64:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<double, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case C64:
        return DoGemmWithAlgorithm<complex64, complex64>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix,
            static_cast<complex64>(alpha), beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case C128:
        return DoGemmWithAlgorithm<complex128, complex128>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
            stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      default:
        LOG(FATAL) << "Unsupported type.";
    }
  }();

  if (!launch_ok) {
    return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla