/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

GpuGemmConfig GetGpuGemmConfig(const HloInstruction *gemm) {
  GpuGemmConfig config;
  config.output_shape = gemm->shape();
  config.lhs_shape = gemm->operand(0)->shape();
  config.rhs_shape = gemm->operand(1)->shape();
  auto backend_config_or = gemm->backend_config<GemmBackendConfig>();
  config.backend_config = std::move(backend_config_or.ValueOrDie());
  return config;
}

GemmThunk::GemmThunk(ThunkInfo thunk_info, GpuGemmConfig config,
                     const BufferAllocation::Slice &lhs_buffer,
                     const BufferAllocation::Slice &rhs_buffer,
                     const BufferAllocation::Slice &output_buffer,
                     bool implements_whole_instruction)
    : Thunk(Kind::kGemm, thunk_info),
      config_(std::move(config)),
      lhs_buffer_(lhs_buffer),
      rhs_buffer_(rhs_buffer),
      output_buffer_(output_buffer),
      implements_whole_instruction_(implements_whole_instruction) {}

Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
  auto get_device_address = [&](const BufferAllocation::Slice &slice) {
    return params.buffer_allocations->GetDeviceAddress(slice);
  };

  VLOG(3) << "Running GEMM thunk";
  se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_);
  se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_);
  se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
  return RunGemm(config_, lhs_data, rhs_data, output_data, params.stream,
                 implements_whole_instruction_, profile_index());
}

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  se::DeviceMemoryBase data;
  se::blas::Transpose transpose;
  int64 num_rows;
  int64 num_cols;
  int64 stride;

  int64 reduced_dim() const {
    return transpose == se::blas::Transpose::kTranspose ? num_rows : num_cols;
  }

  template <typename T>
  se::DeviceMemory<T> cast() const {
    return se::DeviceMemory<T>(data);
  }
};
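
// For illustration only (hypothetical numbers): a non-transposed operand
// with 64 rows and 32 columns would be described as
//   MatrixDescriptor{data, se::blas::Transpose::kNoTranspose,
//                    /*num_rows=*/64, /*num_cols=*/32, /*stride=*/64 * 32},
// and reduced_dim() would then return 32 (num_cols), i.e. the dimension that
// the GEMM contracts over.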

// Converts from an XLA PrimitiveType to a blas::ComputationType, which is
// used to specify the precision with which matmul computations should be
// performed, separately from the precision of the inputs and result.
static absl::optional<se::blas::ComputationType> ComputationTypeFromPrimitive(
    PrimitiveType type) {
  switch (type) {
    case F16:
    case BF16:
      return se::blas::ComputationType::kF32;
    case F32:
      return se::blas::ComputationType::kF32;
    case F64:
      return se::blas::ComputationType::kF64;
    case C64:
      return se::blas::ComputationType::kComplexF32;
    case C128:
      return se::blas::ComputationType::kComplexF64;
    case S32:
      return se::blas::ComputationType::kI32;
    default:
      return absl::nullopt;
  }
}
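
// For example, F16 and BF16 inputs are accumulated in F32
// (ComputationTypeFromPrimitive(F16) == kF32); element types not listed above
// map to absl::nullopt.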

template <typename Input, typename Output>
static Status DoGemmWithAlgorithm(
    int64_t batch_size, MatrixDescriptor lhs, MatrixDescriptor rhs,
    MatrixDescriptor output_matrix, Output alpha, Output beta,
    se::Stream *stream, se::blas::AlgorithmType algorithm,
    se::blas::ProfileResult *output_profile_result) {
  CHECK(output_matrix.transpose == se::blas::Transpose::kNoTranspose);
  PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
  se::blas::ComputationType computation_type =
      *ComputationTypeFromPrimitive(output_type);
  se::DeviceMemory<Output> output_data(output_matrix.data);

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
        lhs.transpose, rhs.transpose, output_matrix.num_rows,
        output_matrix.num_cols,
        /*size of reduce dim=*/lhs.reduced_dim(),
        /*alpha=*/alpha, lhs.cast<Input>(),
        /*leading dim of LHS=*/lhs.num_rows, lhs.stride, rhs.cast<Input>(),
        /*leading dim of RHS=*/rhs.num_rows, rhs.stride,
        /*beta=*/beta, &output_data,
        /*leading dim of output=*/output_matrix.num_rows, output_matrix.stride,
        batch_size, computation_type, algorithm, output_profile_result);
  } else {
    return stream->ThenBlasGemmWithAlgorithm(
        lhs.transpose, rhs.transpose, output_matrix.num_rows,
        output_matrix.num_cols,
        /*size of reduce dim=*/lhs.reduced_dim(),
        /*alpha=*/alpha, lhs.cast<Input>(),
        /*lda=*/lhs.num_rows, rhs.cast<Input>(),
        /*ldb=*/rhs.num_rows,
        /*beta=*/beta, &output_data,
        /*ldc=*/output_matrix.num_rows, computation_type, algorithm,
        output_profile_result);
  }
}

template <typename Input>
static Status DoGemm(int64_t batch_size, const MatrixDescriptor &lhs,
                     const MatrixDescriptor &rhs,
                     const MatrixDescriptor &output_matrix, Input alpha,
                     Input beta, se::Stream *stream,
                     absl::optional<se::blas::AlgorithmType> algorithm,
                     se::blas::ProfileResult *output_profile_result) {
  CHECK(output_matrix.transpose == se::blas::Transpose::kNoTranspose);
  se::DeviceMemory<Input> output_data(output_matrix.data);

  if (algorithm) {
    return DoGemmWithAlgorithm<Input, Input>(batch_size, lhs, rhs,
                                             output_matrix, alpha, beta, stream,
                                             *algorithm, output_profile_result);
  }

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatched(
        lhs.transpose, rhs.transpose, output_matrix.num_rows,
        output_matrix.num_cols, /*size of reduce dim=*/lhs.reduced_dim(),
        /*alpha=*/alpha, lhs.cast<Input>(),
        /*leading dim of LHS=*/lhs.num_rows, lhs.stride, rhs.cast<Input>(),
        /*leading dim of RHS=*/rhs.num_rows, rhs.stride,
        /*beta=*/beta, &output_data,
        /*leading dim of output=*/output_matrix.num_rows, output_matrix.stride,
        batch_size);
  }
  return stream->ThenBlasGemm(
      lhs.transpose, rhs.transpose, output_matrix.num_rows,
      output_matrix.num_cols, /*size of reduce dim=*/lhs.reduced_dim(),
      /*alpha=*/alpha, lhs.cast<Input>(),
      /*leading dim of LHS=*/lhs.num_rows, rhs.cast<Input>(),
      /*leading dim of RHS=*/rhs.num_rows,
      /*beta=*/beta, &output_data,
      /*leading dim of output=*/output_matrix.num_rows);
}

Status RunGemm(const GpuGemmConfig &gemm_config,
               se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream *stream,
               bool implements_whole_instruction,
               absl::optional<int64> profile_index,
               se::blas::ProfileResult *profile_result,
               absl::optional<se::blas::AlgorithmType> algorithm) {
  VLOG(2) << "Executing a GemmThunk";

  const Shape &output_shape = gemm_config.output_shape;
  const Shape &lhs_shape = gemm_config.lhs_shape;
  const Shape &rhs_shape = gemm_config.rhs_shape;
  const GemmBackendConfig &backend_config = gemm_config.backend_config;
  const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers();
  absl::Span<const int64> output_batch_dims =
      AsInt64Slice((dim_nums.lhs_batch_dimensions_size() >
                    dim_nums.rhs_batch_dimensions_size())
                       ? dim_nums.lhs_batch_dimensions()
                       : dim_nums.rhs_batch_dimensions());

  int64_t batch_size = backend_config.batch_size();
  int64_t output_row_dim = output_batch_dims.size();
  int64_t output_col_dim = output_row_dim + 1;

  if (backend_config.rhs_stride() && backend_config.lhs_stride()) {
    CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
             dim_nums.rhs_batch_dimensions_size());
  }

  int64_t output_num_rows = output_shape.dimensions(output_row_dim);
  int64_t output_num_cols = output_shape.dimensions(output_col_dim);

  auto validate_matrix = [&](const Shape &shape, auto batch_dimensions) {
    int64_t row_dim = batch_dimensions.size();
    int64_t col_dim = row_dim + 1;
    CHECK_EQ(row_dim + 2, shape.rank());

    // Check that the batch dims don't cover the last two dims.
    for (int64_t batch_dim : batch_dimensions) {
      CHECK_NE(row_dim, batch_dim);
      CHECK_NE(col_dim, batch_dim);
    }

    // Verify that the non-batch dimensions are minor-most. This is required
    // for efficient access.
    CHECK_LT(shape.layout().minor_to_major(row_dim), 2);
    CHECK_LT(shape.layout().minor_to_major(col_dim), 2);
  };

  validate_matrix(lhs_shape, dim_nums.lhs_batch_dimensions());
  validate_matrix(rhs_shape, dim_nums.rhs_batch_dimensions());
  validate_matrix(output_shape, output_batch_dims);

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert a dot between row-major matrices into one
  // between column-major matrices. The key insight for the conversion is that,
  // in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform a dot between row-major matrices
  // as follows.
  //
  //   row-major(C)
  //     = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //     = gemm(column-major(B^T), column-major(A^T))
  //     = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
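  //
  // As a purely illustrative example (hypothetical shapes): for row-major A
  // of shape [m, k] and B of shape [k, n] producing row-major C of shape
  // [m, n], the gemm call computes C^T (an n x m matrix) with M = n, N = m,
  // K = k, leading dimension of the LHS (B^T) = n, leading dimension of the
  // RHS (A^T) = k, and leading dimension of the output = n.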
  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape &shape,
                             int64_t row_dim, bool transpose,
                             int64_t stride) -> MatrixDescriptor {
    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch =
        LayoutUtil::Minor(shape.layout(), row_dim) !=
        LayoutUtil::Minor(output_shape.layout(), output_row_dim);
    int64_t rows =
        shape.dimensions(row_dim + static_cast<int64_t>(is_row_major));
    int64_t cols =
        shape.dimensions(row_dim + static_cast<int64_t>(!is_row_major));
    if (stride != 0) {
      CHECK_EQ(stride, rows * cols);
    }
    return MatrixDescriptor{data,
                            transpose ^ layout_mismatch
                                ? se::blas::Transpose::kTranspose
                                : se::blas::Transpose::kNoTranspose,
                            rows, cols, stride};
  };

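  // Determine whether each operand is logically transposed: the LHS is if it
  // contracts over its row (second-to-last) dimension, and the RHS is if it
  // contracts over its column (last) dimension. make_descriptor() above folds
  // any layout mismatch with the output into the final transpose flag.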
  bool lhs_transpose = dim_nums.lhs_contracting_dimensions(0) ==
                       dim_nums.lhs_batch_dimensions_size();
  bool rhs_transpose = dim_nums.rhs_contracting_dimensions(0) ==
                       dim_nums.rhs_batch_dimensions_size() + 1;

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_buffer, lhs_shape, dim_nums.lhs_batch_dimensions_size(),
      lhs_transpose, backend_config.lhs_stride());
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_buffer, rhs_shape, dim_nums.rhs_batch_dimensions_size(),
      rhs_transpose, backend_config.rhs_stride());

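  // If the output is stored with the opposite layout, compute the transposed
  // product instead: swapping the operands and the output dimensions yields,
  // by column-major(M) = row-major(M^T), exactly the bytes the output layout
  // expects.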
  if (LayoutUtil::Minor(output_shape.layout(), output_row_dim) != 0) {
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_num_cols, output_num_rows);
  }

  const MatrixDescriptor output_matrix{
      output_buffer, se::blas::Transpose::kNoTranspose, output_num_rows,
      output_num_cols, output_num_rows * output_num_cols};
  auto best_algorithm = [&]() -> absl::optional<se::blas::AlgorithmType> {
    if (algorithm) {
      return *algorithm;
    }
    if (backend_config.algorithm_case() ==
        GemmBackendConfig::ALGORITHM_NOT_SET) {
      return absl::nullopt;
    }
    return backend_config.selected_algorithm();
  }();

  complex128 alpha = {backend_config.alpha_real(), backend_config.alpha_imag()};
  double beta = backend_config.beta();

  switch (output_shape.element_type()) {
    case S32: {
      if (!best_algorithm) {
        return InternalError("Only extended GEMM is supported for int32");
      }
      CHECK_EQ(alpha.imag(), 0);
      if (lhs_shape.element_type() == PrimitiveType::S8 &&
          rhs_shape.element_type() == lhs_shape.element_type()) {
        return DoGemmWithAlgorithm<int8, int32>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix,
            static_cast<int32>(alpha.real()), static_cast<int32>(beta), stream,
            *best_algorithm,
            /*output_profile_result=*/profile_result);
      }
      return InternalError(
          "For int32 gemm output only int8 input is supported, got input: %s",
          primitive_util::LowercasePrimitiveTypeName(lhs_shape.element_type()));
    }
    case F16:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<Eigen::half>(
          batch_size, lhs_matrix, rhs_matrix, output_matrix,
          static_cast<Eigen::half>(alpha.real()),
          static_cast<Eigen::half>(beta), stream, best_algorithm,
          /*output_profile_result=*/profile_result);
    case BF16:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<Eigen::bfloat16>(
          batch_size, lhs_matrix, rhs_matrix, output_matrix,
          static_cast<Eigen::bfloat16>(alpha.real()),
          static_cast<Eigen::bfloat16>(beta), stream, best_algorithm,
          /*output_profile_result=*/profile_result);
    case F32:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<float>(batch_size, lhs_matrix, rhs_matrix, output_matrix,
                           alpha.real(), beta, stream, best_algorithm,
                           /*output_profile_result=*/profile_result);
    case F64:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<double>(batch_size, lhs_matrix, rhs_matrix, output_matrix,
                            alpha.real(), beta, stream, best_algorithm,
                            /*output_profile_result=*/profile_result);
    case C64:
      return DoGemm<complex64>(batch_size, lhs_matrix, rhs_matrix,
                               output_matrix, static_cast<complex64>(alpha),
                               static_cast<complex64>(beta), stream,
                               best_algorithm,
                               /*output_profile_result=*/profile_result);
    case C128:
      return DoGemm<complex128>(
          batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha,
          static_cast<complex128>(beta), stream, best_algorithm,
          /*output_profile_result=*/profile_result);
    default:
      return InternalError("Unexpected GEMM datatype: %s",
                           output_shape.ToString());
  }
}

}  // namespace gpu
}  // namespace xla