/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer,
                     const BufferAllocation::Slice &rhs_buffer,
                     const BufferAllocation::Slice &output_buffer,
                     bool implements_whole_instruction,
                     const HloInstruction *hlo_instruction,
                     const GemmBackendConfig &backend_config)
    : Thunk(Kind::kGemm, hlo_instruction),
      lhs_buffer_(lhs_buffer),
      rhs_buffer_(rhs_buffer),
      output_buffer_(output_buffer),
      implements_whole_instruction_(implements_whole_instruction),
      backend_config_(backend_config) {}
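
// Illustrative only: a hedged sketch of how a GemmThunk might be constructed
// by its caller. In the real codebase thunks are created by the GPU IR/thunk
// emitter, and the variable names below are assumptions, not part of this
// file:
//
//   auto thunk = absl::make_unique<GemmThunk>(
//       lhs_slice, rhs_slice, output_slice,
//       /*implements_whole_instruction=*/true, gemm_instr, gemm_config);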

Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
  auto get_device_address = [&](const BufferAllocation::Slice &slice) {
    return params.buffer_allocations->GetDeviceAddress(slice);
  };

  VLOG(3) << "Running GEMM thunk on instruction: " << hlo_instruction();
  se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_);
  se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_);
  se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
  return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data,
                 output_data, params.stream, implements_whole_instruction_,
                 params.profiler);
}

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  se::DeviceMemoryBase data;
  bool transpose;  // Whether this matrix needs to be transposed.
  int64 num_rows;
  int64 num_cols;
};

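// For instance (an illustrative example, not from the original source): a
// 64 x 32 operand at device address `buf` that BLAS should consume without
// transposition would be described as
//
//   MatrixDescriptor desc{buf, /*transpose=*/false, /*num_rows=*/64,
//                         /*num_cols=*/32};
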
// Performs output = alpha * (lhs x rhs) + beta * output as a gemm call on
// `stream`, per batch element. Dispatches to ThenBlasGemmWithAlgorithm when an
// explicit algorithm is supplied, to ThenBlasGemmStridedBatched when
// batch_size != 1, and to plain ThenBlasGemm otherwise. Returns true iff the
// launch succeeded.
template <typename Element, typename AlphaType>
static bool DoGemmWithAlgorithm(
    int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
    MatrixDescriptor output_matrix, AlphaType alpha, double beta,
    se::Stream *stream, absl::optional<se::blas::AlgorithmType> algorithm,
    se::blas::ProfileResult *output_profile_result) {
  DCHECK(!output_matrix.transpose);

  PrimitiveType type = primitive_util::NativeToPrimitiveType<Element>();

  // Converts from an XLA PrimitiveType to a blas::ComputationType, which is
  // used to specify the precision with which matmul computations should be
  // performed, separately from the precision of the inputs and result.
  se::blas::ComputationType computation_type = [&](PrimitiveType type) {
    switch (type) {
      case F16:
        // Use F32 as computation type for F16 as we currently only implement
        // the cuDNN pseudo half configuration for half precision.
        return se::blas::ComputationType::kF32;
      case F32:
        return se::blas::ComputationType::kF32;
      case F64:
        return se::blas::ComputationType::kF64;
      case C64:
        return se::blas::ComputationType::kComplexF32;
      case C128:
        return se::blas::ComputationType::kComplexF64;
      default:
        LOG(FATAL) << "Unsupported type.";
    }
  }(type);

  se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
  se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
  se::DeviceMemory<Element> output_data(output_matrix.data);

  auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose
                                            : se::blas::Transpose::kNoTranspose;
  // Size of the contracted ("k") dimension: if BLAS is asked to transpose the
  // LHS, its stored rows become the contraction dimension.
  auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

  if (algorithm) {
    // Autotuning is disabled for batch_size != 1.
    CHECK_EQ(1, batch_size);
    return stream
        ->ThenBlasGemmWithAlgorithm(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols,
            /*size of reduce dim=*/k,
            /*alpha=*/static_cast<Element>(alpha), lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows,
            /*beta=*/static_cast<Element>(beta), &output_data,
            /*leading dim of output=*/output_matrix.num_rows, computation_type,
            *algorithm, output_profile_result)
        .ok();
  }

  if (batch_size != 1) {
    int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
    int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
    int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
    return stream
        ->ThenBlasGemmStridedBatched(
            lhs_transpose, rhs_transpose, output_matrix.num_rows,
            output_matrix.num_cols, /*size of reduce dim=*/k,
            /*alpha=*/alpha, lhs_data,
            /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
            /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
            /*beta=*/beta, &output_data,
            /*leading dim of output=*/output_matrix.num_rows, output_stride,
            batch_size)
        .ok();
  }

  return stream
      ->ThenBlasGemm(
          lhs_transpose, rhs_transpose, output_matrix.num_rows,
          output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
          lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
          /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta,
          &output_data, /*leading dim of output=*/output_matrix.num_rows)
      .ok();
}

Status RunGemm(const HloInstruction *gemm,
               const GemmBackendConfig &backend_config,
               se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream *stream,
               bool implements_whole_instruction,
               HloExecutionProfiler *profiler,
               se::blas::ProfileResult *profile_result,
               absl::optional<se::blas::AlgorithmType> algorithm) {
  VLOG(2) << "Executing a GemmThunk";
  CHECK(IsCublasGemm(*gemm));

  const Shape &output_shape = gemm->shape();
  const HloInstruction *lhs = gemm->operand(0);
  const HloInstruction *rhs = gemm->operand(1);

  const Shape &lhs_shape = lhs->shape();
  const Shape &rhs_shape = rhs->shape();

  const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers();
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
           dim_nums.rhs_batch_dimensions_size());
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank());

  int64 row_dim = dim_nums.lhs_batch_dimensions_size();
  int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;

  int64 batch_size = backend_config.batch_size();

  // Check that the batch dims don't cover the last two dims.
  for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
    CHECK_NE(row_dim, batch_dim);
    CHECK_NE(col_dim, batch_dim);
  }

  // Verify that the non-batch dimensions are minor-most. This is required for
  // efficient access.
  for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) {
    CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
    CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
  }

  // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
  // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
  // their layout. Therefore, we should treat dimension 0 as row and dimension 1
  // as column when mapping a matrix Dot to BLAS gemm.
  int64 output_num_rows = output_shape.dimensions(row_dim);
  int64 output_num_cols = output_shape.dimensions(col_dim);

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert a dot between row-major matrices into one
  // between column-major matrices. The key insight for the conversion is that,
  // in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform dot between row-major matrices as
  // follows.
  //
  // row-major(C)
  //   = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //   = gemm(column-major(B^T), column-major(A^T))
  //   = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
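  //
  // Concretely (an illustrative example, not from the original source): let A
  // be 2x3 and B be 3x4, both row-major, so C = A x B is 2x4 row-major. The
  // same bytes, reinterpreted column-major, describe B^T (4x3) and A^T (3x2).
  // Calling gemm with m=4, n=2, k=3, B's buffer as the left operand, and A's
  // buffer as the right operand yields C^T in column-major order, whose bytes
  // are exactly row-major C.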
  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape &shape,
                             bool transpose) -> MatrixDescriptor {
    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
                           LayoutUtil::Minor(output_shape.layout(), row_dim);
    return MatrixDescriptor{
        data, static_cast<bool>(transpose ^ layout_mismatch),
        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
        shape.dimensions(row_dim + static_cast<int64>(!is_row_major))};
  };
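  // In effect: a row-major operand is presented to BLAS as its transpose
  // (stored dimensions swapped), and the transpose flag is flipped whenever an
  // operand's layout disagrees with the output's, so no data is physically
  // moved.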

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_buffer, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim);
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim);
  std::unique_ptr<ScopedInstructionProfiler> op_profiler =
      profiler ? profiler->MakeScopedInstructionProfiler(
                     implements_whole_instruction ? gemm : nullptr)
               : nullptr;

  // If the output is row-major, swap the operands: gemm then computes
  // column-major C^T = B^T x A^T, which is row-major C (see the comment
  // above).
  if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) {
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_num_cols, output_num_rows);
  }

  const MatrixDescriptor output_matrix{output_buffer, /*transpose=*/false,
                                       output_num_rows, output_num_cols};
  // Prefer an explicitly requested algorithm; otherwise fall back to the one
  // recorded in the backend config, if any.
  auto best_algorithm = [&]() -> absl::optional<se::blas::AlgorithmType> {
    if (algorithm) {
      return *algorithm;
    }
    if (backend_config.algorithm_case() ==
        GemmBackendConfig::ALGORITHM_NOT_SET) {
      return absl::nullopt;
    }
    return backend_config.selected_algorithm();
  }();

  complex128 alpha = {backend_config.alpha_real(), backend_config.alpha_imag()};
  double beta = backend_config.beta();

  bool launch_ok = [&]() {
    switch (output_shape.element_type()) {
      case F16:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<Eigen::half, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case F32:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<float, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case F64:
        CHECK_EQ(alpha.imag(), 0);
        return DoGemmWithAlgorithm<double, double>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha.real(),
            beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case C64:
        return DoGemmWithAlgorithm<complex64, complex64>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix,
            static_cast<complex64>(alpha), beta, stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      case C128:
        return DoGemmWithAlgorithm<complex128, complex128>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta,
            stream, best_algorithm,
            /*output_profile_result=*/profile_result);
      default:
        LOG(FATAL) << "Unsupported type.";
    }
  }();

  if (!launch_ok) {
    return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
  }
  return Status::OK();
}
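
// Illustrative only: a hedged sketch of how a caller (e.g. an autotuner) might
// drive RunGemm with an explicit algorithm and collect timings. The variable
// names and the call site are assumptions, not part of this file:
//
//   se::blas::ProfileResult profile;
//   Status status =
//       RunGemm(gemm, config, lhs_mem, rhs_mem, out_mem, stream,
//               /*implements_whole_instruction=*/true, /*profiler=*/nullptr,
//               &profile, candidate_algorithm);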

}  // namespace gpu
}  // namespace xla