• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
17 
18 #include <functional>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/stream_executor/blas.h"
31 #include "tensorflow/stream_executor/device_memory.h"
32 
33 namespace xla {
34 namespace gpu {
35 
GetGpuGemmConfig(const HloInstruction * gemm)36 GpuGemmConfig GetGpuGemmConfig(const HloInstruction *gemm) {
37   GpuGemmConfig config;
38   config.output_shape = gemm->shape();
39   config.lhs_shape = gemm->operand(0)->shape();
40   config.rhs_shape = gemm->operand(1)->shape();
41   auto backend_config_or = gemm->backend_config<GemmBackendConfig>();
42   config.backend_config = std::move(backend_config_or.ValueOrDie());
43   return config;
44 }
45 
GemmThunk(ThunkInfo thunk_info,GpuGemmConfig config,const BufferAllocation::Slice & lhs_buffer,const BufferAllocation::Slice & rhs_buffer,const BufferAllocation::Slice & output_buffer,bool implements_whole_instruction)46 GemmThunk::GemmThunk(ThunkInfo thunk_info, GpuGemmConfig config,
47                      const BufferAllocation::Slice &lhs_buffer,
48                      const BufferAllocation::Slice &rhs_buffer,
49                      const BufferAllocation::Slice &output_buffer,
50                      bool implements_whole_instruction)
51     : Thunk(Kind::kGemm, thunk_info),
52       config_(std::move(config)),
53       lhs_buffer_(lhs_buffer),
54       rhs_buffer_(rhs_buffer),
55       output_buffer_(output_buffer),
56       implements_whole_instruction_(implements_whole_instruction) {}
57 
ExecuteOnStream(const ExecuteParams & params)58 Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
59   auto get_device_address = [&](const BufferAllocation::Slice &slice) {
60     return params.buffer_allocations->GetDeviceAddress(slice);
61   };
62 
63   VLOG(3) << "Running GEMM thunk";
64   se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_);
65   se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_);
66   se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
67   return RunGemm(config_, lhs_data, rhs_data, output_data, params.stream,
68                  implements_whole_instruction_, profile_index(),
69                  params.profiler);
70 }
71 
72 // This struct contains the metadata of a matrix, e.g., its base address and
73 // dimensions.
74 struct MatrixDescriptor {
75   se::DeviceMemoryBase data;
76   bool transpose;  // Whether this matrix needs to be transposed.
77   int64 num_rows;
78   int64 num_cols;
79 };
80 
81 template <typename Element, typename AlphaType>
DoGemmWithAlgorithm(int64 batch_size,MatrixDescriptor lhs_matrix,MatrixDescriptor rhs_matrix,MatrixDescriptor output_matrix,AlphaType alpha,double beta,se::Stream * stream,absl::optional<se::blas::AlgorithmType> algorithm,se::blas::ProfileResult * output_profile_result)82 static bool DoGemmWithAlgorithm(
83     int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
84     MatrixDescriptor output_matrix, AlphaType alpha, double beta,
85     se::Stream *stream, absl::optional<se::blas::AlgorithmType> algorithm,
86     se::blas::ProfileResult *output_profile_result) {
87   DCHECK(!output_matrix.transpose);
88 
89   PrimitiveType type = primitive_util::NativeToPrimitiveType<Element>();
90 
91   // Converts from an XLA PrimitiveType to a blas::ComputationType, which is
92   // used to specify the precision with which matmul computations should be
93   // performed, separately from the precision of the inputs and result.
94   se::blas::ComputationType computation_type;
95   switch (type) {
96     case F16:
97       // Use F32 as computation type for F16 as we currently only implement
98       // the cuDNN pseudo half configuration for half precision.
99       computation_type = se::blas::ComputationType::kF32;
100       break;
101     case F32:
102       computation_type = se::blas::ComputationType::kF32;
103       break;
104     case F64:
105       computation_type = se::blas::ComputationType::kF64;
106       break;
107     case C64:
108       computation_type = se::blas::ComputationType::kComplexF32;
109       break;
110     case C128:
111       computation_type = se::blas::ComputationType::kComplexF64;
112       break;
113     default:
114       return false;
115   }
116 
117   se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
118   se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
119   se::DeviceMemory<Element> output_data(output_matrix.data);
120 
121   auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
122                                             : se::blas::Transpose::kNoTranspose;
123   auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
124                                             : se::blas::Transpose::kNoTranspose;
125   auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
126 
127   if (algorithm) {
128     // Autotuning is disabled for batch_size != 1.
129     CHECK_EQ(1, batch_size);
130     return stream
131         ->ThenBlasGemmWithAlgorithm(
132             lhs_transpose, rhs_transpose, output_matrix.num_rows,
133             output_matrix.num_cols,
134             /*size of reduce dim=*/k,
135             /*alpha=*/static_cast<Element>(alpha), lhs_data,
136             /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
137             /*leading dim of RHS=*/rhs_matrix.num_rows,
138             /*beta=*/static_cast<Element>(beta), &output_data,
139             /*leading dim of output=*/output_matrix.num_rows, computation_type,
140             *algorithm, output_profile_result)
141         .ok();
142   }
143 
144   if (batch_size != 1) {
145     int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
146     int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
147     int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
148     return stream
149         ->ThenBlasGemmStridedBatched(
150             lhs_transpose, rhs_transpose, output_matrix.num_rows,
151             output_matrix.num_cols, /*size of reduce dim=*/k,
152             /*alpha=*/alpha, lhs_data,
153             /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
154             /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
155             /*beta=*/beta, &output_data,
156             /*leading dim of output=*/output_matrix.num_rows, output_stride,
157             batch_size)
158         .ok();
159   }
160 
161   return stream
162       ->ThenBlasGemm(
163           lhs_transpose, rhs_transpose, output_matrix.num_rows,
164           output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
165           lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
166           /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
167           &output_data, /*leading dim of output=*/output_matrix.num_rows)
168       .ok();
169 }
170 
RunGemm(const GpuGemmConfig & gemm_config,se::DeviceMemoryBase lhs_buffer,se::DeviceMemoryBase rhs_buffer,se::DeviceMemoryBase output_buffer,se::Stream * stream,bool implements_whole_instruction,absl::optional<int64> profile_index,HloExecutionProfiler * profiler,se::blas::ProfileResult * profile_result,absl::optional<se::blas::AlgorithmType> algorithm)171 Status RunGemm(const GpuGemmConfig &gemm_config,
172                se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
173                se::DeviceMemoryBase output_buffer, se::Stream *stream,
174                bool implements_whole_instruction,
175                absl::optional<int64> profile_index,
176                HloExecutionProfiler *profiler,
177                se::blas::ProfileResult *profile_result,
178                absl::optional<se::blas::AlgorithmType> algorithm) {
179   VLOG(2) << "Executing a GemmThunk";
180 
181   const Shape &output_shape = gemm_config.output_shape;
182   const Shape &lhs_shape = gemm_config.lhs_shape;
183   const Shape &rhs_shape = gemm_config.rhs_shape;
184   const GemmBackendConfig &backend_config = gemm_config.backend_config;
185 
186   const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers();
187   CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
188            dim_nums.rhs_batch_dimensions_size());
189   CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank());
190 
191   int64 row_dim = dim_nums.lhs_batch_dimensions_size();
192   int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
193 
194   int64 batch_size = backend_config.batch_size();
195 
196   // Check that the batch dims don't cover the last two dims.
197   for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
198     CHECK_NE(row_dim, batch_dim);
199     CHECK_NE(col_dim, batch_dim);
200   }
201 
202   // Verify that the non-batch dimensions are minor-most. This is required for
203   // efficient access.
204   for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) {
205     CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
206     CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
207   }
208 
209   int64 output_num_rows = output_shape.dimensions(row_dim);
210   int64 output_num_cols = output_shape.dimensions(col_dim);
211 
212   // BLAS gemm expects the inputs and the output are in column-major order.
213   // Therefore, we need to convert dot between row-major matrices to that
214   // between column-major matrices. The key insight for the conversion is that,
215   // in linear storage, matrix M in column-major order is identical to the
216   // transpose of M in row-major order. In other words,
217   //
218   //   column-major(M) = row-major(M^T).
219   //
220   // Leveraging this insight, we can perform dot between row-major matrices as
221   // follows.
222   //
223   // row-major(C)
224   //   = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
225   //   = gemm(column-major(B^T), column-major(A^T))
226   //   = gemm(row-major(B), row-major(A))
227   //
228   // Although we do not modify the content of A and B in linear memory, we
229   // should use the dimensions of B^T and A^T when calling gemm. For example,
230   // the leading dimension of the LHS matrix of gemm is the number of rows in
231   // B^T and thus the number of columns in B.
232   auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape &shape,
233                              bool transpose) -> MatrixDescriptor {
234     bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
235     bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
236                            LayoutUtil::Minor(output_shape.layout(), row_dim);
237     return MatrixDescriptor{
238         data, static_cast<bool>(transpose ^ layout_mismatch),
239         shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
240         shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
241   };
242 
243   MatrixDescriptor lhs_matrix = make_descriptor(
244       lhs_buffer, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim);
245   MatrixDescriptor rhs_matrix = make_descriptor(
246       rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
247   std::unique_ptr<ScopedInstructionProfiler> op_profiler =
248       profiler ? profiler->MakeScopedInstructionProfiler(
249                      implements_whole_instruction ? profile_index : -1)
250                : nullptr;
251 
252   if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
253     std::swap(lhs_matrix, rhs_matrix);
254     std::swap(output_num_cols, output_num_rows);
255   }
256 
257   const MatrixDescriptor output_matrix{output_buffer, /*needs_transpose=*/false,
258                                        output_num_rows, output_num_cols};
259   auto best_algorithm = [&]() -> absl::optional<se::blas::AlgorithmType> {
260     if (algorithm) {
261       return *algorithm;
262     }
263     if (backend_config.algorithm_case() ==
264         GemmBackendConfig::ALGORITHM_NOT_SET) {
265       return absl::nullopt;
266     }
267     return backend_config.selected_algorithm();
268   }();
269 
270   complex128 alpha = {backend_config.alpha_real(), backend_config.alpha_imag()};
271   double beta = backend_config.beta();
272 
273   bool launch_ok = [&]() {
274     switch (output_shape.element_type()) {
275       case F16:
276         CHECK_EQ(alpha.imag(), 0);
277         return DoGemmWithAlgorithm<Eigen::half, double>(
278             batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
279             beta, stream, best_algorithm,
280             /*output_profile_result=*/profile_result);
281       case F32:
282         CHECK_EQ(alpha.imag(), 0);
283         return DoGemmWithAlgorithm<float, double>(
284             batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
285             beta, stream, best_algorithm,
286             /*output_profile_result=*/profile_result);
287       case F64:
288         CHECK_EQ(alpha.imag(), 0);
289         return DoGemmWithAlgorithm<double, double>(
290             batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
291             beta, stream, best_algorithm,
292             /*output_profile_result=*/profile_result);
293       case C64:
294         return DoGemmWithAlgorithm<complex64, complex64>(
295             batch_size, lhs_matrix, rhs_matrix, output_matrix,
296             static_cast<complex64>(alpha), beta, stream, best_algorithm,
297             /*output_profile_result=*/profile_result);
298       case C128:
299         return DoGemmWithAlgorithm<complex128, complex128>(
300             batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
301             stream, best_algorithm,
302             /*output_profile_result=*/profile_result);
303       default:
304         return false;
305     }
306   }();
307 
308   if (!launch_ok) {
309     return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
310   }
311   return Status::OK();
312 }
313 
314 }  // namespace gpu
315 }  // namespace xla
316