/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"

#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

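// Collects everything needed to run a GEMM for the given instruction: the
// operand and result shapes plus the GemmBackendConfig attached to it.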
GpuGemmConfig GetGpuGemmConfig(const HloInstruction *gemm) {
  GpuGemmConfig config;
  config.output_shape = gemm->shape();
  config.lhs_shape = gemm->operand(0)->shape();
  config.rhs_shape = gemm->operand(1)->shape();
  auto backend_config_or = gemm->backend_config<GemmBackendConfig>();
  config.backend_config = std::move(backend_config_or.ValueOrDie());
  return config;
}

GemmThunk::GemmThunk(ThunkInfo thunk_info, GpuGemmConfig config,
                     const BufferAllocation::Slice &lhs_buffer,
                     const BufferAllocation::Slice &rhs_buffer,
                     const BufferAllocation::Slice &output_buffer,
                     bool implements_whole_instruction)
    : Thunk(Kind::kGemm, thunk_info),
      config_(std::move(config)),
      lhs_buffer_(lhs_buffer),
      rhs_buffer_(rhs_buffer),
      output_buffer_(output_buffer),
      implements_whole_instruction_(implements_whole_instruction) {}

Status GemmThunk::ExecuteOnStream(const ExecuteParams &params) {
  auto get_device_address = [&](const BufferAllocation::Slice &slice) {
    return params.buffer_allocations->GetDeviceAddress(slice);
  };

  VLOG(3) << "Running GEMM thunk";
  se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_);
  se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_);
  se::DeviceMemoryBase output_data = get_device_address(output_buffer_);
  return RunGemm(config_, lhs_data, rhs_data, output_data, params.stream,
                 implements_whole_instruction_, profile_index());
}

// This struct contains the metadata of a matrix, e.g., its base address and
// dimensions.
struct MatrixDescriptor {
  se::DeviceMemoryBase data;
  se::blas::Transpose transpose;
  int64 num_rows;
  int64 num_cols;
  int64 stride;

  int64 reduced_dim() const {
    return transpose == se::blas::Transpose::kTranspose ? num_rows : num_cols;
  }

  template <typename T>
  se::DeviceMemory<T> cast() const {
    return se::DeviceMemory<T>(data);
  }
};

// Converts from an XLA PrimitiveType to a blas::ComputationType, which is
// used to specify the precision with which matmul computations should be
// performed, separately from the precision of the inputs and result.
static absl::optional<se::blas::ComputationType> ComputationTypeFromPrimitive(
    PrimitiveType type) {
  switch (type) {
    case F16:
    case BF16:
      return se::blas::ComputationType::kF32;
    case F32:
      return se::blas::ComputationType::kF32;
    case F64:
      return se::blas::ComputationType::kF64;
    case C64:
      return se::blas::ComputationType::kComplexF32;
    case C128:
      return se::blas::ComputationType::kComplexF64;
    case S32:
      return se::blas::ComputationType::kI32;
    default:
      return absl::nullopt;
  }
}

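// Runs a single GEMM (or a strided-batched GEMM when batch_size != 1) with an
// explicitly chosen BLAS algorithm. Input and Output may be different element
// types (e.g. int8 inputs accumulating into an int32 output); the computation
// type is derived from the Output type.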
template <typename Input, typename Output>
static Status DoGemmWithAlgorithm(
    int64_t batch_size, MatrixDescriptor lhs, MatrixDescriptor rhs,
    MatrixDescriptor output_matrix, Output alpha, Output beta,
    se::Stream *stream, se::blas::AlgorithmType algorithm,
    se::blas::ProfileResult *output_profile_result) {
  CHECK(output_matrix.transpose == se::blas::Transpose::kNoTranspose);
  PrimitiveType output_type = primitive_util::NativeToPrimitiveType<Output>();
  se::blas::ComputationType computation_type =
      *ComputationTypeFromPrimitive(output_type);
  se::DeviceMemory<Output> output_data(output_matrix.data);

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
        lhs.transpose, rhs.transpose, output_matrix.num_rows,
        output_matrix.num_cols,
        /*size of reduce dim=*/lhs.reduced_dim(),
        /*alpha=*/alpha, lhs.cast<Input>(),
        /*leading dim of LHS=*/lhs.num_rows, lhs.stride, rhs.cast<Input>(),
        /*leading dim of RHS=*/rhs.num_rows, rhs.stride,
        /*beta=*/beta, &output_data,
        /*leading dim of output=*/output_matrix.num_rows, output_matrix.stride,
        batch_size, computation_type, algorithm, output_profile_result);
  } else {
    return stream->ThenBlasGemmWithAlgorithm(
        lhs.transpose, rhs.transpose, output_matrix.num_rows,
        output_matrix.num_cols,
        /*size of reduce dim=*/lhs.reduced_dim(),
        /*alpha=*/alpha, lhs.cast<Input>(),
        /*lda=*/lhs.num_rows, rhs.cast<Input>(),
        /*ldb=*/rhs.num_rows,
        /*beta=*/beta, &output_data,
        /*ldc=*/output_matrix.num_rows, computation_type, algorithm,
        output_profile_result);
  }
}

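// Runs a GEMM whose inputs and output share the element type Input. If a
// specific algorithm is requested it dispatches to DoGemmWithAlgorithm;
// otherwise it uses the default (strided-batched or plain) BLAS GEMM calls.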
template <typename Input>
static Status DoGemm(int64_t batch_size, const MatrixDescriptor &lhs,
                     const MatrixDescriptor &rhs,
                     const MatrixDescriptor &output_matrix, Input alpha,
                     Input beta, se::Stream *stream,
                     absl::optional<se::blas::AlgorithmType> algorithm,
                     se::blas::ProfileResult *output_profile_result) {
  CHECK(output_matrix.transpose == se::blas::Transpose::kNoTranspose);
  se::DeviceMemory<Input> output_data(output_matrix.data);

  if (algorithm) {
    return DoGemmWithAlgorithm<Input, Input>(batch_size, lhs, rhs,
                                             output_matrix, alpha, beta, stream,
                                             *algorithm, output_profile_result);
  }

  if (batch_size != 1) {
    return stream->ThenBlasGemmStridedBatched(
        lhs.transpose, rhs.transpose, output_matrix.num_rows,
        output_matrix.num_cols, /*size of reduce dim=*/lhs.reduced_dim(),
        /*alpha=*/alpha, lhs.cast<Input>(),
        /*leading dim of LHS=*/lhs.num_rows, lhs.stride, rhs.cast<Input>(),
        /*leading dim of RHS=*/rhs.num_rows, rhs.stride,
        /*beta=*/beta, &output_data,
        /*leading dim of output=*/output_matrix.num_rows, output_matrix.stride,
        batch_size);
  }
  return stream->ThenBlasGemm(
      lhs.transpose, rhs.transpose, output_matrix.num_rows,
      output_matrix.num_cols, /*size of reduce dim=*/lhs.reduced_dim(),
      /*alpha=*/alpha, lhs.cast<Input>(),
      /*leading dim of LHS=*/lhs.num_rows, rhs.cast<Input>(),
      /*leading dim of RHS=*/rhs.num_rows,
      /*beta=*/beta, &output_data,
      /*leading dim of output=*/output_matrix.num_rows);
}

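// Runs the GEMM described by gemm_config on stream, reading the operands from
// lhs_buffer and rhs_buffer and writing the result to output_buffer. The
// output element type selects which BLAS entry point and input/output types
// are used.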
Status RunGemm(const GpuGemmConfig &gemm_config,
               se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer,
               se::DeviceMemoryBase output_buffer, se::Stream *stream,
               bool implements_whole_instruction,
               absl::optional<int64> profile_index,
               se::blas::ProfileResult *profile_result,
               absl::optional<se::blas::AlgorithmType> algorithm) {
  VLOG(2) << "Executing a GemmThunk";

  const Shape &output_shape = gemm_config.output_shape;
  const Shape &lhs_shape = gemm_config.lhs_shape;
  const Shape &rhs_shape = gemm_config.rhs_shape;
  const GemmBackendConfig &backend_config = gemm_config.backend_config;
  const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers();
  absl::Span<const int64> output_batch_dims =
      AsInt64Slice((dim_nums.lhs_batch_dimensions_size() >
                    dim_nums.rhs_batch_dimensions_size())
                       ? dim_nums.lhs_batch_dimensions()
                       : dim_nums.rhs_batch_dimensions());

  int64_t batch_size = backend_config.batch_size();
  int64_t output_row_dim = output_batch_dims.size();
  int64_t output_col_dim = output_row_dim + 1;

  if (backend_config.rhs_stride() && backend_config.lhs_stride()) {
    CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
             dim_nums.rhs_batch_dimensions_size());
  }

  int64_t output_num_rows = output_shape.dimensions(output_row_dim);
  int64_t output_num_cols = output_shape.dimensions(output_col_dim);

  auto validate_matrix = [&](const Shape &shape, auto batch_dimensions) {
    int64_t row_dim = batch_dimensions.size();
    int64_t col_dim = row_dim + 1;
    CHECK_EQ(row_dim + 2, shape.rank());

    // Check that the batch dims don't cover the last two dims.
    for (int64_t batch_dim : batch_dimensions) {
      CHECK_NE(row_dim, batch_dim);
      CHECK_NE(col_dim, batch_dim);
    }

    // Verify that the non-batch dimensions are minor-most. This is required
    // for efficient access.
    CHECK_LT(shape.layout().minor_to_major(row_dim), 2);
    CHECK_LT(shape.layout().minor_to_major(col_dim), 2);
  };

  validate_matrix(lhs_shape, dim_nums.lhs_batch_dimensions());
  validate_matrix(rhs_shape, dim_nums.rhs_batch_dimensions());
  validate_matrix(output_shape, output_batch_dims);

  // BLAS gemm expects the inputs and the output to be in column-major order.
  // Therefore, we need to convert a dot between row-major matrices into one
  // between column-major matrices. The key insight for the conversion is that,
  // in linear storage, matrix M in column-major order is identical to the
  // transpose of M in row-major order. In other words,
  //
  //   column-major(M) = row-major(M^T).
  //
  // Leveraging this insight, we can perform a dot between row-major matrices
  // as follows.
  //
  // row-major(C)
  //   = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
  //   = gemm(column-major(B^T), column-major(A^T))
  //   = gemm(row-major(B), row-major(A))
  //
  // Although we do not modify the content of A and B in linear memory, we
  // should use the dimensions of B^T and A^T when calling gemm. For example,
  // the leading dimension of the LHS matrix of gemm is the number of rows in
  // B^T and thus the number of columns in B.
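  //
  // As a concrete illustration of the identity above: for the 2x3 matrix
  //   M = [1 2 3]
  //       [4 5 6]
  // the column-major linear storage of M is {1, 4, 2, 5, 3, 6}, which is
  // exactly the row-major linear storage of the 3x2 matrix M^T. A row-major
  // buffer can therefore be handed to BLAS as its column-major transpose
  // without moving any data; only the dimensions are reinterpreted.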
  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape &shape,
                             int64_t row_dim, bool transpose,
                             int64_t stride) -> MatrixDescriptor {
    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
    bool layout_mismatch =
        LayoutUtil::Minor(shape.layout(), row_dim) !=
        LayoutUtil::Minor(output_shape.layout(), output_row_dim);
    int64_t rows =
        shape.dimensions(row_dim + static_cast<int64_t>(is_row_major));
    int64_t cols =
        shape.dimensions(row_dim + static_cast<int64_t>(!is_row_major));
    if (stride != 0) {
      CHECK_EQ(stride, rows * cols);
    }
    return MatrixDescriptor{data,
                            transpose ^ layout_mismatch
                                ? se::blas::Transpose::kTranspose
                                : se::blas::Transpose::kNoTranspose,
                            rows, cols, stride};
  };

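  // A plain (non-transposed) dot contracts the LHS over its column dimension
  // and the RHS over its row dimension. With shapes laid out as
  // [batch dims..., row, col], the row dimension index equals the number of
  // batch dimensions, so contracting the LHS over that index (or the RHS over
  // the following index) means the operand is logically transposed.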
  bool lhs_transpose = dim_nums.lhs_contracting_dimensions(0) ==
                       dim_nums.lhs_batch_dimensions_size();
  bool rhs_transpose = dim_nums.rhs_contracting_dimensions(0) ==
                       dim_nums.rhs_batch_dimensions_size() + 1;

  MatrixDescriptor lhs_matrix = make_descriptor(
      lhs_buffer, lhs_shape, dim_nums.lhs_batch_dimensions_size(),
      lhs_transpose, backend_config.lhs_stride());
  MatrixDescriptor rhs_matrix = make_descriptor(
      rhs_buffer, rhs_shape, dim_nums.rhs_batch_dimensions_size(),
      rhs_transpose, backend_config.rhs_stride());

  if (LayoutUtil::Minor(output_shape.layout(), output_row_dim) != 0) {
    std::swap(lhs_matrix, rhs_matrix);
    std::swap(output_num_cols, output_num_rows);
  }

  const MatrixDescriptor output_matrix{
      output_buffer, se::blas::Transpose::kNoTranspose, output_num_rows,
      output_num_cols, output_num_rows * output_num_cols};
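
  // Choose which BLAS algorithm to use: an algorithm passed in explicitly
  // takes precedence, then one recorded in the backend config (e.g. by
  // autotuning); otherwise fall back to the default BLAS GEMM entry points.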
  auto best_algorithm = [&]() -> absl::optional<se::blas::AlgorithmType> {
    if (algorithm) {
      return *algorithm;
    }
    if (backend_config.algorithm_case() ==
        GemmBackendConfig::ALGORITHM_NOT_SET) {
      return absl::nullopt;
    }
    return backend_config.selected_algorithm();
  }();

  complex128 alpha = {backend_config.alpha_real(), backend_config.alpha_imag()};
  double beta = backend_config.beta();

  switch (output_shape.element_type()) {
    case S32: {
      if (!best_algorithm) {
        return InternalError("Only extended GEMM is supported for int32");
      }
      CHECK_EQ(alpha.imag(), 0);
      if (lhs_shape.element_type() == PrimitiveType::S8 &&
          rhs_shape.element_type() == lhs_shape.element_type()) {
        return DoGemmWithAlgorithm<int8, int32>(
            batch_size, lhs_matrix, rhs_matrix, output_matrix,
            static_cast<int32>(alpha.real()), static_cast<int32>(beta), stream,
            *best_algorithm,
            /*output_profile_result=*/profile_result);
      }
      return InternalError(
          "For int32 gemm output only int8 input is supported, got input: %s",
          primitive_util::LowercasePrimitiveTypeName(lhs_shape.element_type()));
    }
    case F16:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<Eigen::half>(
          batch_size, lhs_matrix, rhs_matrix, output_matrix,
          static_cast<Eigen::half>(alpha.real()),
          static_cast<Eigen::half>(beta), stream, best_algorithm,
          /*output_profile_result=*/profile_result);
    case BF16:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<Eigen::bfloat16>(
          batch_size, lhs_matrix, rhs_matrix, output_matrix,
          static_cast<Eigen::bfloat16>(alpha.real()),
          static_cast<Eigen::bfloat16>(beta), stream, best_algorithm,
          /*output_profile_result=*/profile_result);
    case F32:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<float>(batch_size, lhs_matrix, rhs_matrix, output_matrix,
                           alpha.real(), beta, stream, best_algorithm,
                           /*output_profile_result=*/profile_result);
    case F64:
      CHECK_EQ(alpha.imag(), 0);
      return DoGemm<double>(batch_size, lhs_matrix, rhs_matrix, output_matrix,
                            alpha.real(), beta, stream, best_algorithm,
                            /*output_profile_result=*/profile_result);
    case C64:
      return DoGemm<complex64>(batch_size, lhs_matrix, rhs_matrix,
                               output_matrix, static_cast<complex64>(alpha),
                               static_cast<complex64>(beta), stream,
                               best_algorithm,
                               /*output_profile_result=*/profile_result);
    case C128:
      return DoGemm<complex128>(
          batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha,
          static_cast<complex128>(beta), stream, best_algorithm,
          /*output_profile_result=*/profile_result);
    default:
      return InternalError("Unexpected GEMM datatype: %s",
                           output_shape.ToString());
  }
}

}  // namespace gpu
}  // namespace xla