1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "third_party/gpus/cuda/include/cublasLt.h"
17 #include "third_party/gpus/cuda/include/cublas_v2.h"
18 #include "third_party/gpus/cuda/include/cuda.h"
19 
20 #define SE_CUDA_DATA_HALF CUDA_R_16F
21 
22 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
23 
24 // Both Eigen Half.h and CUDA cuda_fp16.h provide a similar typedef for __half.
25 // As such, there are two ways to get that typedef:
26 //
27 // (1) Includes cuda_fp16.h and defines EIGEN_HAS_CUDA_FP16.
28 // (2) Neither includes cuda_fp16.h nor defines EIGEN_HAS_CUDA_FP16.
29 //
30 // Due to issue b/73793421, when the first approach is used and NVCC is used to
31 // compile this file, NVCC complains about a duplicated definition of
32 // EIGEN_HAS_CUDA_FP16. On the other hand, when the second approach is used and
33 // clang is used to compile this file, clang does not understand __half because
34 // cuda_fp16.h is not included and EIGEN_HAS_CUDA_FP16 is not defined.
35 //
36 // Because this file may be compiled with clang but will never be compiled with
37 // NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we have
38 // to use the second approach because the data member in the __half defined
39 // by CUDA >= 9.0 is `__x` while Eigen expects it to be `x`.
40 //
41 // TODO(b/73793421): Remove the following code block to switch to the second
42 // approach when the issue is fixed.
43 #if CUDA_VERSION < 9000
44 #include "third_party/gpus/cuda/include/cuda_fp16.h"
45 #define EIGEN_HAS_CUDA_FP16
46 #endif
47 
48 #include <complex>
49 
50 #include "absl/strings/str_cat.h"
51 #include "absl/strings/str_format.h"
52 #include "third_party/eigen3/Eigen/Core"
53 #include "tensorflow/core/platform/tensor_float_32_utils.h"
54 #include "tensorflow/core/util/env_var.h"
55 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
56 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
57 #include "tensorflow/stream_executor/cuda/cuda_helpers.h"
58 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
59 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
60 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
61 #include "tensorflow/stream_executor/device_memory.h"
62 #include "tensorflow/stream_executor/lib/env.h"
63 #include "tensorflow/stream_executor/lib/initialize.h"
64 #include "tensorflow/stream_executor/lib/status.h"
65 #include "tensorflow/stream_executor/lib/status_macros.h"
66 #include "tensorflow/stream_executor/platform/logging.h"
67 #include "tensorflow/stream_executor/platform/port.h"
68 #include "tensorflow/stream_executor/plugin_registry.h"
69 #include "tensorflow/stream_executor/scratch_allocator.h"
70 #include "tensorflow/stream_executor/stream_executor.h"
71 
72 namespace stream_executor {
73 namespace gpu {
74 
75 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
76 
77 static std::string ToString(cublasStatus_t status) {
78   switch (status) {
79     case CUBLAS_STATUS_SUCCESS:
80       return "CUBLAS_STATUS_SUCCESS";
81     case CUBLAS_STATUS_NOT_INITIALIZED:
82       return "CUBLAS_STATUS_NOT_INITIALIZED";
83     case CUBLAS_STATUS_ALLOC_FAILED:
84       return "CUBLAS_STATUS_ALLOC_FAILED";
85     case CUBLAS_STATUS_INVALID_VALUE:
86       return "CUBLAS_STATUS_INVALID_VALUE";
87     case CUBLAS_STATUS_ARCH_MISMATCH:
88       return "CUBLAS_STATUS_ARCH_MISMATCH";
89     case CUBLAS_STATUS_MAPPING_ERROR:
90       return "CUBLAS_STATUS_MAPPING_ERROR";
91     case CUBLAS_STATUS_EXECUTION_FAILED:
92       return "CUBLAS_STATUS_EXECUTION_FAILED";
93     case CUBLAS_STATUS_INTERNAL_ERROR:
94       return "CUBLAS_STATUS_INTERNAL_ERROR";
95 #if CUDA_VERSION >= 8000
96     case CUBLAS_STATUS_NOT_SUPPORTED:
97       return "CUBLAS_STATUS_NOT_SUPPORTED";
98     case CUBLAS_STATUS_LICENSE_ERROR:
99       return "CUBLAS_STATUS_LICENSE_ERROR";
100 #endif
101     default:
102       return absl::StrCat("<invalid cublas status: ", status, ">");
103   }
104 }
105 
106 // cuBLAS has interfaces that permit pointers to be passed from either the host
107 // memory space or the device memory space; however, you must instruct it as to
108 // which address space those pointers are in with cublasSetPointerMode.
109 //
110 // This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
111 // you are about to perform in a given scope.
112 //
113 // The prior cuBLAS pointer mode is retained and restored when this object goes
114 // out of scope.
115 class ScopedCublasPointerMode {
116  public:
117   // Note that, because the setting of the cublas pointer mode is fallible,
118   // construction of this scoped datatype must be paired with a call to
119   // Init().
120   //
121   // Parameters:
122   //  handle: The cublas library handle to act upon in setting the pointer mode.
123   explicit ScopedCublasPointerMode(cublasHandle_t handle)
124       : handle_(handle), ok_(false) {}
125 
126   // Attempts the switch to the requested scoped pointer mode, new_mode.
127   //
128   // Note that when false is returned, an appropriate error has already been
129   // logged.
130   bool Init(cublasPointerMode_t new_mode) {
131     cublasStatus_t ret = cublasGetPointerMode(handle_, &old_mode_);
132     if (ret != CUBLAS_STATUS_SUCCESS) {
133       LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
134       return ok_ = false;
135     }
136 
137     ret = cublasSetPointerMode(handle_, new_mode);
138     if (ret != CUBLAS_STATUS_SUCCESS) {
139       LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
140       return ok_ = false;
141     }
142 
143     return ok_ = true;
144   }
145 
146   // Switches back to the prior pointer mode, if the switch operation was
147   // successful in the first place.
148   ~ScopedCublasPointerMode() {
149     if (ok_) {
150       cublasStatus_t ret = cublasSetPointerMode(handle_, old_mode_);
151       if (ret != CUBLAS_STATUS_SUCCESS) {
152         LOG(ERROR) << "failed to set former cublas pointer mode: "
153                    << ToString(ret);
154       }
155     }
156   }
157 
158  private:
159   cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
160   cublasPointerMode_t old_mode_;  // Prior cuBLAS pointer mode, to be restored.
161   bool ok_;                       // Whether the change was successful.
162 };
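// A minimal usage sketch (the real call site is DoBlasInternalImpl further
// below in this file; `handle` is a placeholder):
//
//   ScopedCublasPointerMode pointer_mode{handle};
//   if (!pointer_mode.Init(CUBLAS_POINTER_MODE_HOST)) {
//     return port::InternalError("failed to set cuBLAS pointer mode");
//   }
//   // cuBLAS calls issued here read alpha/beta from host memory; the prior
//   // pointer mode is restored when pointer_mode goes out of scope.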
163 
164 #if CUDA_VERSION >= 9000
165 // cuBLAS has interfaces that let computations use tensor-op math on Volta and
166 // newer hardware; this must be enabled via the cublasGet/SetMathMode APIs.
167 //
168 // This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
169 // you are about to perform in a given scope.
170 //
171 // The prior cuBLAS math mode is retained and restored when this object goes
172 // out of scope.
173 class ScopedCublasMathMode {
174  public:
175   // Note that, because the setting of the cublas math mode is fallible,
176   // construction of this scoped datatype must be paired with a call to
177   // Init().
178   //
179   // Parameters:
180   //  handle: The cublas library handle to act upon in setting the math mode.
181   explicit ScopedCublasMathMode(cublasHandle_t handle)
182       : handle_(handle), ok_(false) {}
183 
184   // Attempts the switch to the requested scoped math mode, new_mode.
185   //
186   // Note that when false is returned, an appropriate error has already been
187   // logged.
188   bool Init(cublasMath_t new_mode) {
189     cublasStatus_t ret = cublasGetMathMode(handle_, &old_mode_);
190     if (ret != CUBLAS_STATUS_SUCCESS) {
191       LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
192       return ok_ = false;
193     }
194 
195     ret = cublasSetMathMode(handle_, new_mode);
196     if (ret != CUBLAS_STATUS_SUCCESS) {
197       LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
198       return ok_ = false;
199     }
200     return ok_ = true;
201   }
202 
203   // Switches back to the prior math mode, if the switch operation was
204   // successful in the first place.
205   ~ScopedCublasMathMode() {
206     if (ok_) {
207       cublasStatus_t ret = cublasSetMathMode(handle_, old_mode_);
208       if (ret != CUBLAS_STATUS_SUCCESS) {
209         LOG(ERROR) << "failed to set former cublas math mode: "
210                    << ToString(ret);
211       }
212     }
213   }
214 
215  private:
216   cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
217   cublasMath_t old_mode_;  // Prior cuBLAS math mode, to be restored.
218   bool ok_;                // Whether the change was successful.
219 };
220 #endif  // CUDA_VERSION >= 9000
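// Usage mirrors ScopedCublasPointerMode above; a minimal sketch (the real
// call site, DoBlasInternalImpl below, additionally gates tensor-op math on
// tensorflow::tensor_float_32_execution_enabled() when built against
// cuBLAS 11+; `handle` is a placeholder):
//
//   ScopedCublasMathMode math_mode{handle};
//   if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
//     return port::InternalError("failed to set cuBLAS math mode");
//   }
//   // cuBLAS calls issued here may use tensor-op math; the prior mode is
//   // restored when math_mode goes out of scope.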
221 
222 bool CUDABlas::Init() {
223   gpu::ScopedActivateExecutorContext sac{parent_};
224   cublasStatus_t ret = cublasCreate(&blas_);
225   if (ret != CUBLAS_STATUS_SUCCESS) {
226     LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
227     return false;
228   }
229 
230 #if CUDA_VERSION >= 11000
231   ret = cublasLtCreate(&blasLt_);
232   if (ret != CUBLAS_STATUS_SUCCESS) {
233     LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
234     return false;
235   }
236 #endif  // CUDA_VERSION >= 11000
237 
238   return true;
239 }
240 
241 CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
242     : parent_(CHECK_NOTNULL(parent)),
243       blas_(nullptr)
244 #if CUDA_VERSION >= 11000
245       ,
246       blasLt_(nullptr)
247 #endif
248 {
249 }
250 
251 CUDABlas::~CUDABlas() {
252   if (blas_ != nullptr) {
253     gpu::ScopedActivateExecutorContext sac{parent_};
254     cublasDestroy(blas_);
255   }
256 #if CUDA_VERSION >= 11000
257   if (blasLt_ != nullptr) {
258     gpu::ScopedActivateExecutorContext sac{parent_};
259     cublasLtDestroy(blasLt_);
260   }
261 #endif
262 }
263 
264 bool CUDABlas::SetStream(Stream *stream) {
265   CHECK(stream != nullptr);
266   CHECK(AsGpuStreamValue(stream) != nullptr);
267   CHECK(blas_ != nullptr);
268   gpu::ScopedActivateExecutorContext sac{parent_};
269   cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
270   if (ret != CUBLAS_STATUS_SUCCESS) {
271     LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
272     return false;
273   }
274 
275   return true;
276 }
277 
278 cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
279   CHECK(stream != nullptr);
280   CHECK(AsGpuStreamValue(stream) != nullptr);
281   gpu::ScopedActivateExecutorContext sac{parent_};
282   return AsGpuStreamValue(stream);
283 }
284 
285 namespace {
286 
287 // Helper functions transforming blas arguments into cuBLAS arguments.
288 
289 cublasOperation_t CUDABlasTranspose(blas::Transpose trans) {
290   switch (trans) {
291     case blas::Transpose::kNoTranspose:
292       return CUBLAS_OP_N;
293     case blas::Transpose::kTranspose:
294       return CUBLAS_OP_T;
295     case blas::Transpose::kConjugateTranspose:
296       return CUBLAS_OP_C;
297     default:
298       LOG(FATAL) << "Invalid value of blas::Transpose.";
299   }
300 }
301 
302 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
303   switch (uplo) {
304     case blas::UpperLower::kUpper:
305       return CUBLAS_FILL_MODE_UPPER;
306     case blas::UpperLower::kLower:
307       return CUBLAS_FILL_MODE_LOWER;
308     default:
309       LOG(FATAL) << "Invalid value of blas::UpperLower.";
310   }
311 }
312 
313 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
314   switch (diag) {
315     case blas::Diagonal::kUnit:
316       return CUBLAS_DIAG_UNIT;
317     case blas::Diagonal::kNonUnit:
318       return CUBLAS_DIAG_NON_UNIT;
319     default:
320       LOG(FATAL) << "Invalid value of blas::Diagonal.";
321   }
322 }
323 
324 cublasSideMode_t CUDABlasSide(blas::Side side) {
325   switch (side) {
326     case blas::Side::kLeft:
327       return CUBLAS_SIDE_LEFT;
328     case blas::Side::kRight:
329       return CUBLAS_SIDE_RIGHT;
330     default:
331       LOG(FATAL) << "Invalid value of blas::Side.";
332   }
333 }
334 
335 // CUDADataType<T>::type translates from a C++ type (e.g. float) to a
336 // cudaDataType_t (e.g. CUDA_R_32F).  CUDAComputationType(ty) translates from a
337 // blas::ComputationType to a cudaDataType_t.
338 //
339 // These are used to build the argument type and computation type args to
340 // cublasGemmEx.
341 template <typename T>
342 struct CUDADataType;
343 
344 template <>
345 struct CUDADataType<Eigen::half> {
346   static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
347 };
348 
349 template <>
350 struct CUDADataType<std::complex<Eigen::half>> {
351   static constexpr cudaDataType_t type = CUDA_C_16F;
352 };
353 
354 template <>
355 struct CUDADataType<float> {
356   static constexpr cudaDataType_t type = CUDA_R_32F;
357 };
358 
359 template <>
360 struct CUDADataType<std::complex<float>> {
361   static constexpr cudaDataType_t type = CUDA_C_32F;
362 };
363 
364 template <>
365 struct CUDADataType<double> {
366   static constexpr cudaDataType_t type = CUDA_R_64F;
367 };
368 
369 template <>
370 struct CUDADataType<std::complex<double>> {
371   static constexpr cudaDataType_t type = CUDA_C_64F;
372 };
373 
374 template <>
375 struct CUDADataType<int> {
376   static constexpr cudaDataType_t type = CUDA_R_32I;
377 };
378 
379 template <>
380 struct CUDADataType<int8> {
381   static constexpr cudaDataType_t type = CUDA_R_8I;
382 };
383 
384 template <>
385 struct CUDADataType<std::complex<int8>> {
386   static constexpr cudaDataType_t type = CUDA_C_8I;
387 };
388 
389 template <>
390 struct CUDADataType<uint8> {
391   static constexpr cudaDataType_t type = CUDA_R_8U;
392 };
393 
394 template <>
395 struct CUDADataType<std::complex<uint8>> {
396   static constexpr cudaDataType_t type = CUDA_C_8U;
397 };
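// Compile-time sanity checks for the mapping above; these mirror the
// specializations and have no runtime effect:
static_assert(CUDADataType<float>::type == CUDA_R_32F,
              "CUDADataType<float> must map to CUDA_R_32F");
static_assert(CUDADataType<double>::type == CUDA_R_64F,
              "CUDADataType<double> must map to CUDA_R_64F");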
398 
399 cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
400   switch (ty) {
401     case blas::ComputationType::kF16:
402       return CUDA_R_16F;
403     case blas::ComputationType::kF32:
404       return CUDA_R_32F;
405     case blas::ComputationType::kF64:
406       return CUDA_R_64F;
407     case blas::ComputationType::kI32:
408       return CUDA_R_32I;
409     case blas::ComputationType::kComplexF32:
410       return CUDA_C_32F;
411     case blas::ComputationType::kComplexF64:
412       return CUDA_C_64F;
413     case blas::ComputationType::kTF32AsF32:  // fall-through
414     case blas::ComputationType::kBF16AsF32:
415       // These cases are currently only supported in the blasLt routines, which
416       // use CUBLASComputationType() instead.
417       LOG(FATAL) << "Invalid value of blas::ComputationType.";
418   }
419 }
420 
421 #if CUDA_VERSION >= 11000
422 cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
423   switch (ty) {
424     case blas::ComputationType::kF16:
425       return CUBLAS_COMPUTE_16F;
426     case blas::ComputationType::kF32:  // fall-through
427     case blas::ComputationType::kComplexF32:
428       return CUBLAS_COMPUTE_32F;
429     case blas::ComputationType::kF64:  // fall-through
430     case blas::ComputationType::kComplexF64:
431       return CUBLAS_COMPUTE_64F;
432     case blas::ComputationType::kI32:
433       return CUBLAS_COMPUTE_32I;
434     case blas::ComputationType::kTF32AsF32:
435       return CUBLAS_COMPUTE_32F_FAST_TF32;
436     case blas::ComputationType::kBF16AsF32:
437       return CUBLAS_COMPUTE_32F_FAST_16BF;
438   }
439 }
440 #endif  // CUDA_VERSION >= 11000
441 
442 blas::DataType GetScaleType(blas::DataType data_type,
443                             blas::ComputationType compute_type) {
444   bool is_complex = data_type == blas::DataType::kComplexFloat ||
445                     data_type == blas::DataType::kComplexDouble;
446   switch (compute_type) {
447     case blas::ComputationType::kF16:
448       return blas::DataType::kHalf;
449     case blas::ComputationType::kF32:         // fall-through
450     case blas::ComputationType::kComplexF32:  // fall-through
451     case blas::ComputationType::kTF32AsF32:   // fall-through
452     case blas::ComputationType::kBF16AsF32:
453       return is_complex ? blas::DataType::kComplexFloat
454                         : blas::DataType::kFloat;
455     case blas::ComputationType::kF64:  // fall-through
456     case blas::ComputationType::kComplexF64:
457       return is_complex ? blas::DataType::kComplexDouble
458                         : blas::DataType::kDouble;
459     case blas::ComputationType::kI32:
460       return blas::DataType::kInt32;
461   }
462 }
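// For example, GetScaleType(blas::DataType::kComplexFloat,
// blas::ComputationType::kF32) is blas::DataType::kComplexFloat: alpha and
// beta keep the (possibly complex) element type even when the computation
// type is F32, while GetScaleType(blas::DataType::kDouble,
// blas::ComputationType::kF64) is simply blas::DataType::kDouble.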
463 
464 #if CUDA_VERSION >= 11000
465 cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
466   switch (pointer_mode) {
467     case blas::PointerMode::kHost:
468       return CUBLASLT_POINTER_MODE_HOST;
469     case blas::PointerMode::kDevice:
470       return CUBLASLT_POINTER_MODE_DEVICE;
471   }
472 }
473 cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
474   switch (epilogue) {
475     case blas::Epilogue::kDefault:
476       return CUBLASLT_EPILOGUE_DEFAULT;
477     case blas::Epilogue::kReLU:
478       return CUBLASLT_EPILOGUE_RELU;
479     case blas::Epilogue::kBias:
480       return CUBLASLT_EPILOGUE_BIAS;
481     case blas::Epilogue::kBiasThenReLU:
482       return CUBLASLT_EPILOGUE_RELU_BIAS;
483   }
484 }
485 #endif  // CUDA_VERSION >= 11000
486 
487 cudaDataType_t GetCUDADataType(blas::DataType ty) {
488   switch (ty) {
489     case blas::DataType::kHalf:
490       return CUDA_R_16F;
491 #if CUDA_VERSION >= 11000
492     case blas::DataType::kBF16:
493       return CUDA_R_16BF;
494 #endif
495     case blas::DataType::kFloat:
496       return CUDA_R_32F;
497     case blas::DataType::kDouble:
498       return CUDA_R_64F;
499     case blas::DataType::kInt8:
500       return CUDA_R_8I;
501     case blas::DataType::kInt32:
502       return CUDA_R_32I;
503     case blas::DataType::kComplexFloat:
504       return CUDA_C_32F;
505     case blas::DataType::kComplexDouble:
506       return CUDA_C_64F;
507     default:
508       LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType";
509   }
510 }
511 
512 int GetDataTypeSizeBytes(blas::DataType ty) {
513   switch (ty) {
514     case blas::DataType::kHalf:
515       return 2;
516     case blas::DataType::kFloat:
517       return 4;
518     case blas::DataType::kDouble:
519       return 8;
520     case blas::DataType::kInt8:
521       return 1;
522     case blas::DataType::kInt32:
523       return 4;
524     case blas::DataType::kComplexFloat:
525       return 8;
526     case blas::DataType::kComplexDouble:
527       return 16;
528     default:
529       LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes";
530   }
531 }
532 
533 }  // namespace
534 
535 template <typename FuncT, typename... Args>
536 port::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
537                                           bool pointer_mode_host,
538                                           cublasMath_t math_type,
539                                           Args... args) {
540   absl::MutexLock lock(&mu_);
541 
542   CHECK(blas_ != nullptr);
543   if (!SetStream(stream)) {
544     return port::InternalError("Failed setting stream");
545   }
546 
547 #if CUDA_VERSION >= 9000
548   ScopedCublasMathMode math_mode{blas_};
549 #if CUBLAS_VER_MAJOR >= 11
550   if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
551       tensorflow::tensor_float_32_execution_enabled()) {
552 #else
553   if (math_type == CUBLAS_TENSOR_OP_MATH) {
554 #endif
555     if (!math_mode.Init(math_type)) {
556       return port::InternalError("Failed initializing math mode");
557     }
558   }
559 #endif
560 
561   gpu::ScopedActivateExecutorContext sac{parent_};
562   ScopedCublasPointerMode pointer_mode{blas_};
563   if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
564                                            : CUBLAS_POINTER_MODE_DEVICE)) {
565     return port::InternalError("Failed setting pointer mode");
566   }
567   cublasStatus_t ret = cublas_func(blas_, args...);
568   if (ret == CUBLAS_STATUS_SUCCESS) {
569     return port::Status::OK();
570   }
571   return port::InternalError(ToString(ret));
572 }
573 
574 // cublas_func may be overloaded, so we need to figure out which one we really
575 // need to call based on the args. One way to do that is to wrap it in a lambda.
576 #define AS_LAMBDA(func)                                                  \
577   [](auto &&... args) -> decltype(                                       \
578                           func(std::forward<decltype(args)>(args)...)) { \
579     return func(std::forward<decltype(args)>(args)...);                  \
580   }
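// For example, an entry point with multiple overloads can be forwarded as
// AS_LAMBDA(func); a sketch (cublasSomeOverloadedFunc is a placeholder, not a
// real cuBLAS symbol):
//
//   DoBlasInternalImpl(AS_LAMBDA(cublasSomeOverloadedFunc), stream,
//                      /*pointer_mode_host=*/true, math_type, args...);
//
// Wrapping in a lambda defers overload resolution until the forwarded
// argument types are known.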
581 
582 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
583                           const DeviceMemory<float> &x, int incx,
584                           DeviceMemory<float> *result) {
585   return DoBlasInternal(cublasSasum, stream, false /* = pointer_mode_host */,
586                         elem_count, GpuMemory(x), incx,
587                         GpuMemoryMutable(result));
588 }
589 
590 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
591                           const DeviceMemory<double> &x, int incx,
592                           DeviceMemory<double> *result) {
593   return DoBlasInternal(cublasDasum, stream, false /* = pointer_mode_host */,
594                         elem_count, GpuMemory(x), incx,
595                         GpuMemoryMutable(result));
596 }
597 
598 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
599                           const DeviceMemory<std::complex<float>> &x, int incx,
600                           DeviceMemory<float> *result) {
601   return DoBlasInternal(cublasScasum, stream, false /* = pointer_mode_host */,
602                         elem_count, GpuComplex(GpuMemory(x)), incx,
603                         GpuMemoryMutable(result));
604 }
605 
606 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
607                           const DeviceMemory<std::complex<double>> &x, int incx,
608                           DeviceMemory<double> *result) {
609   return DoBlasInternal(cublasDzasum, stream, false /* = pointer_mode_host */,
610                         elem_count, GpuComplex(GpuMemory(x)), incx,
611                         GpuMemoryMutable(result));
612 }
613 
614 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
615                           const DeviceMemory<float> &x, int incx,
616                           DeviceMemory<float> *y, int incy) {
617   return DoBlasInternal(cublasSaxpy, stream, true /* = pointer_mode_host */,
618                         elem_count, &alpha, GpuMemory(x), incx,
619                         GpuMemoryMutable(y), incy);
620 }
621 
622 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
623                           const DeviceMemory<double> &x, int incx,
624                           DeviceMemory<double> *y, int incy) {
625   return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */,
626                         elem_count, &alpha, GpuMemory(x), incx,
627                         GpuMemoryMutable(y), incy);
628 }
629 
630 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
631                           std::complex<float> alpha,
632                           const DeviceMemory<std::complex<float>> &x, int incx,
633                           DeviceMemory<std::complex<float>> *y, int incy) {
634   auto cb_alpha = GpuComplexValue(alpha);
635   return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
636                         elem_count, GpuComplex(&cb_alpha),
637                         GpuComplex(GpuMemory(x)), incx,
638                         GpuComplex(GpuMemoryMutable(y)), incy);
639 }
640 
641 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
642                           std::complex<double> alpha,
643                           const DeviceMemory<std::complex<double>> &x, int incx,
644                           DeviceMemory<std::complex<double>> *y, int incy) {
645   auto cb_alpha = GpuComplexValue(alpha);
646   return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
647                         elem_count, GpuComplex(&cb_alpha),
648                         GpuComplex(GpuMemory(x)), incx,
649                         GpuComplex(GpuMemoryMutable(y)), incy);
650 }
651 
652 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
653                           const DeviceMemory<float> &x, int incx,
654                           DeviceMemory<float> *y, int incy) {
655   return DoBlasInternal(cublasScopy, stream, true /* = pointer_mode_host */,
656                         elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
657                         incy);
658 }
659 
660 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
661                           const DeviceMemory<double> &x, int incx,
662                           DeviceMemory<double> *y, int incy) {
663   return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */,
664                         elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
665                         incy);
666 }
667 
668 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
669                           const DeviceMemory<std::complex<float>> &x, int incx,
670                           DeviceMemory<std::complex<float>> *y, int incy) {
671   return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */,
672                         elem_count, GpuComplex(GpuMemory(x)), incx,
673                         GpuComplex(GpuMemoryMutable(y)), incy);
674 }
675 
676 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
677                           const DeviceMemory<std::complex<double>> &x, int incx,
678                           DeviceMemory<std::complex<double>> *y, int incy) {
679   return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */,
680                         elem_count, GpuComplex(GpuMemory(x)), incx,
681                         GpuComplex(GpuMemoryMutable(y)), incy);
682 }
683 
684 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
685                          const DeviceMemory<float> &x, int incx,
686                          const DeviceMemory<float> &y, int incy,
687                          DeviceMemory<float> *result) {
688   return DoBlasInternal(cublasSdot, stream, false /* = pointer_mode_host */,
689                         elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
690                         GpuMemoryMutable(result));
691 }
692 
693 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
694                          const DeviceMemory<double> &x, int incx,
695                          const DeviceMemory<double> &y, int incy,
696                          DeviceMemory<double> *result) {
697   return DoBlasInternal(cublasDdot, stream, false /* = pointer_mode_host */,
698                         elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
699                         GpuMemoryMutable(result));
700 }
701 
702 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
703                           const DeviceMemory<std::complex<float>> &x, int incx,
704                           const DeviceMemory<std::complex<float>> &y, int incy,
705                           DeviceMemory<std::complex<float>> *result) {
706   return DoBlasInternal(cublasCdotc, stream, false /* = pointer_mode_host */,
707                         elem_count, GpuComplex(GpuMemory(x)), incx,
708                         GpuComplex(GpuMemory(y)), incy,
709                         GpuComplex(GpuMemoryMutable(result)));
710 }
711 
712 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
713                           const DeviceMemory<std::complex<double>> &x, int incx,
714                           const DeviceMemory<std::complex<double>> &y, int incy,
715                           DeviceMemory<std::complex<double>> *result) {
716   return DoBlasInternal(cublasZdotc, stream, false /* = pointer_mode_host */,
717                         elem_count, GpuComplex(GpuMemory(x)), incx,
718                         GpuComplex(GpuMemory(y)), incy,
719                         GpuComplex(GpuMemoryMutable(result)));
720 }
721 
722 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
723                           const DeviceMemory<std::complex<float>> &x, int incx,
724                           const DeviceMemory<std::complex<float>> &y, int incy,
725                           DeviceMemory<std::complex<float>> *result) {
726   return DoBlasInternal(cublasCdotu, stream, false /* = pointer_mode_host */,
727                         elem_count, GpuComplex(GpuMemory(x)), incx,
728                         GpuComplex(GpuMemory(y)), incy,
729                         GpuComplex(GpuMemoryMutable(result)));
730 }
731 
732 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
733                           const DeviceMemory<std::complex<double>> &x, int incx,
734                           const DeviceMemory<std::complex<double>> &y, int incy,
735                           DeviceMemory<std::complex<double>> *result) {
736   return DoBlasInternal(cublasZdotu, stream, false /* = pointer_mode_host */,
737                         elem_count, GpuComplex(GpuMemory(x)), incx,
738                         GpuComplex(GpuMemory(y)), incy,
739                         GpuComplex(GpuMemoryMutable(result)));
740 }
741 
742 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
743                           const DeviceMemory<float> &x, int incx,
744                           DeviceMemory<float> *result) {
745   return DoBlasInternal(cublasSnrm2, stream, false /* = pointer_mode_host */,
746                         elem_count, GpuMemory(x), incx,
747                         GpuMemoryMutable(result));
748 }
749 
750 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
751                           const DeviceMemory<double> &x, int incx,
752                           DeviceMemory<double> *result) {
753   return DoBlasInternal(cublasDnrm2, stream, false /* = pointer_mode_host */,
754                         elem_count, GpuMemory(x), incx,
755                         GpuMemoryMutable(result));
756 }
757 
758 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
759                           const DeviceMemory<std::complex<float>> &x, int incx,
760                           DeviceMemory<float> *result) {
761   return DoBlasInternal(cublasScnrm2, stream, false /* = pointer_mode_host */,
762                         elem_count, GpuComplex(GpuMemory(x)), incx,
763                         GpuMemoryMutable(result));
764 }
765 
766 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
767                           const DeviceMemory<std::complex<double>> &x, int incx,
768                           DeviceMemory<double> *result) {
769   return DoBlasInternal(cublasDznrm2, stream, false /* = pointer_mode_host */,
770                         elem_count, GpuComplex(GpuMemory(x)), incx,
771                         GpuMemoryMutable(result));
772 }
773 
774 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
775                          DeviceMemory<float> *x, int incx,
776                          DeviceMemory<float> *y, int incy, float c, float s) {
777   return DoBlasInternal(cublasSrot, stream, true /* = pointer_mode_host */,
778                         elem_count, GpuMemoryMutable(x), incx,
779                         GpuMemoryMutable(y), incy, &c, &s);
780 }
781 
782 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
783                          DeviceMemory<double> *x, int incx,
784                          DeviceMemory<double> *y, int incy, double c,
785                          double s) {
786   return DoBlasInternal(cublasDrot, stream, true /* = pointer_mode_host */,
787                         elem_count, GpuMemoryMutable(x), incx,
788                         GpuMemoryMutable(y), incy, &c, &s);
789 }
790 
791 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
792                          DeviceMemory<std::complex<float>> *x, int incx,
793                          DeviceMemory<std::complex<float>> *y, int incy,
794                          float c, float s) {
795   return DoBlasInternal(cublasCsrot, stream, true /* = pointer_mode_host */,
796                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
797                         GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
798 }
799 
800 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
801                          DeviceMemory<std::complex<double>> *x, int incx,
802                          DeviceMemory<std::complex<double>> *y, int incy,
803                          double c, double s) {
804   return DoBlasInternal(cublasZdrot, stream, true /* = pointer_mode_host */,
805                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
806                         GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
807 }
808 
809 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
810                           DeviceMemory<float> *b, DeviceMemory<float> *c,
811                           DeviceMemory<float> *s) {
812   return DoBlasInternal(cublasSrotg, stream, false /* = pointer_mode_host */,
813                         GpuMemoryMutable(a), GpuMemoryMutable(b),
814                         GpuMemoryMutable(c), GpuMemoryMutable(s));
815 }
816 
817 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
818                           DeviceMemory<double> *b, DeviceMemory<double> *c,
819                           DeviceMemory<double> *s) {
820   return DoBlasInternal(cublasDrotg, stream, false /* = pointer_mode_host */,
821                         GpuComplex(GpuMemoryMutable(a)), GpuMemoryMutable(b),
822                         GpuMemoryMutable(c), GpuMemoryMutable(s));
823 }
824 
825 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
826                           DeviceMemory<std::complex<float>> *b,
827                           DeviceMemory<float> *c,
828                           DeviceMemory<std::complex<float>> *s) {
829   return DoBlasInternal(
830       cublasCrotg, stream, false /* = pointer_mode_host */,
831       GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
832       GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
833 }
834 
835 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
836                           DeviceMemory<std::complex<double>> *b,
837                           DeviceMemory<double> *c,
838                           DeviceMemory<std::complex<double>> *s) {
839   return DoBlasInternal(
840       cublasZrotg, stream, false /* = pointer_mode_host */,
841       GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
842       GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
843 }
844 
845 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
846                           DeviceMemory<float> *x, int incx,
847                           DeviceMemory<float> *y, int incy,
848                           const DeviceMemory<float> &param) {
849   return DoBlasInternal(cublasSrotm, stream, false /* = pointer_mode_host */,
850                         elem_count, GpuMemoryMutable(x), incx,
851                         GpuMemoryMutable(y), incy, GpuMemory(param));
852 }
853 
854 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
855                           DeviceMemory<double> *x, int incx,
856                           DeviceMemory<double> *y, int incy,
857                           const DeviceMemory<double> &param) {
858   return DoBlasInternal(cublasDrotm, stream, false /* = pointer_mode_host */,
859                         elem_count, GpuMemoryMutable(x), incx,
860                         GpuMemoryMutable(y), incy, GpuMemory(param));
861 }
862 
863 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
864                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
865                            const DeviceMemory<float> &y1,
866                            DeviceMemory<float> *param) {
867   return DoBlasInternal(cublasSrotmg, stream, false /* = pointer_mode_host */,
868                         GpuMemoryMutable(d1), GpuMemoryMutable(d2),
869                         GpuMemoryMutable(x1), GpuMemory(y1),
870                         GpuMemoryMutable(param));
871 }
872 
873 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
874                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
875                            const DeviceMemory<double> &y1,
876                            DeviceMemory<double> *param) {
877   return DoBlasInternal(cublasDrotmg, stream, false /* = pointer_mode_host */,
878                         GpuMemoryMutable(d1), GpuMemoryMutable(d2),
879                         GpuMemoryMutable(x1), GpuMemory(y1),
880                         GpuMemoryMutable(param));
881 }
882 
883 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
884                           DeviceMemory<float> *x, int incx) {
885   return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */,
886                         elem_count, &alpha, GpuMemoryMutable(x), incx);
887 }
888 
889 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
890                           DeviceMemory<double> *x, int incx) {
891   return DoBlasInternal(cublasDscal, stream, true /* = pointer_mode_host */,
892                         elem_count, &alpha, GpuMemoryMutable(x), incx);
893 }
894 
895 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
896                           DeviceMemory<std::complex<float>> *x, int incx) {
897   return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
898                         elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
899                         incx);
900 }
901 
902 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
903                           DeviceMemory<std::complex<double>> *x, int incx) {
904   return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
905                         elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
906                         incx);
907 }
908 
909 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
910                           std::complex<float> alpha,
911                           DeviceMemory<std::complex<float>> *x, int incx) {
912   auto cb_alpha = GpuComplexValue(alpha);
913   return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
914                         elem_count, GpuComplex(&cb_alpha),
915                         GpuComplex(GpuMemoryMutable(x)), incx);
916 }
917 
918 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
919                           std::complex<double> alpha,
920                           DeviceMemory<std::complex<double>> *x, int incx) {
921   auto cb_alpha = GpuComplexValue(alpha);
922   return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
923                         elem_count, GpuComplex(&cb_alpha),
924                         GpuComplex(GpuMemoryMutable(x)), incx);
925 }
926 
927 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
928                           DeviceMemory<float> *x, int incx,
929                           DeviceMemory<float> *y, int incy) {
930   return DoBlasInternal(cublasSswap, stream, true /* = pointer_mode_host */,
931                         elem_count, GpuMemoryMutable(x), incx,
932                         GpuMemoryMutable(y), incy);
933 }
934 
935 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
936                           DeviceMemory<double> *x, int incx,
937                           DeviceMemory<double> *y, int incy) {
938   return DoBlasInternal(cublasDswap, stream, true /* = pointer_mode_host */,
939                         elem_count, GpuMemoryMutable(x), incx,
940                         GpuMemoryMutable(y), incy);
941 }
942 
943 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
944                           DeviceMemory<std::complex<float>> *x, int incx,
945                           DeviceMemory<std::complex<float>> *y, int incy) {
946   return DoBlasInternal(cublasCswap, stream, true /* = pointer_mode_host */,
947                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
948                         GpuComplex(GpuMemoryMutable(y)), incy);
949 }
950 
951 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
952                           DeviceMemory<std::complex<double>> *x, int incx,
953                           DeviceMemory<std::complex<double>> *y, int incy) {
954   return DoBlasInternal(cublasZswap, stream, true /* = pointer_mode_host */,
955                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
956                         GpuComplex(GpuMemoryMutable(y)), incy);
957 }
958 
959 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
960                            const DeviceMemory<float> &x, int incx,
961                            DeviceMemory<int> *result) {
962   return DoBlasInternal(cublasIsamax, stream, false /* = pointer_mode_host */,
963                         elem_count, GpuMemory(x), incx,
964                         GpuMemoryMutable(result));
965 }
966 
967 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
968                            const DeviceMemory<double> &x, int incx,
969                            DeviceMemory<int> *result) {
970   return DoBlasInternal(cublasIdamax, stream, false /* = pointer_mode_host */,
971                         elem_count, GpuMemory(x), incx,
972                         GpuMemoryMutable(result));
973 }
974 
975 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
976                            const DeviceMemory<std::complex<float>> &x, int incx,
977                            DeviceMemory<int> *result) {
978   return DoBlasInternal(cublasIcamax, stream, false /* = pointer_mode_host */,
979                         elem_count, GpuComplex(GpuMemory(x)), incx,
980                         GpuMemoryMutable(result));
981 }
982 
983 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
984                            const DeviceMemory<std::complex<double>> &x,
985                            int incx, DeviceMemory<int> *result) {
986   return DoBlasInternal(cublasIzamax, stream, false /* = pointer_mode_host */,
987                         elem_count, GpuComplex(GpuMemory(x)), incx,
988                         GpuMemoryMutable(result));
989 }
990 
991 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
992                            const DeviceMemory<float> &x, int incx,
993                            DeviceMemory<int> *result) {
994   return DoBlasInternal(cublasIsamin, stream, false /* = pointer_mode_host */,
995                         elem_count, GpuComplex(GpuMemory(x)), incx,
996                         GpuMemoryMutable(result));
997 }
998 
999 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1000                            const DeviceMemory<double> &x, int incx,
1001                            DeviceMemory<int> *result) {
1002   return DoBlasInternal(cublasIdamin, stream, false /* = pointer_mode_host */,
1003                         elem_count, GpuComplex(GpuMemory(x)), incx,
1004                         GpuMemoryMutable(result));
1005 }
1006 
1007 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1008                            const DeviceMemory<std::complex<float>> &x, int incx,
1009                            DeviceMemory<int> *result) {
1010   return DoBlasInternal(cublasIcamin, stream, false /* = pointer_mode_host */,
1011                         elem_count, GpuComplex(GpuMemory(x)), incx,
1012                         GpuMemoryMutable(result));
1013 }
1014 
1015 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1016                            const DeviceMemory<std::complex<double>> &x,
1017                            int incx, DeviceMemory<int> *result) {
1018   return DoBlasInternal(cublasIzamin, stream, false /* = pointer_mode_host */,
1019                         elem_count, GpuComplex(GpuMemory(x)), incx,
1020                         GpuMemoryMutable(result));
1021 }
1022 
1023 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1024                           uint64 n, uint64 kl, uint64 ku, float alpha,
1025                           const DeviceMemory<float> &a, int lda,
1026                           const DeviceMemory<float> &x, int incx, float beta,
1027                           DeviceMemory<float> *y, int incy) {
1028   return DoBlasInternal(cublasSgbmv, stream, true /* = pointer_mode_host */,
1029                         CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1030                         GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1031                         GpuMemoryMutable(y), incy);
1032 }
1033 
1034 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1035                           uint64 n, uint64 kl, uint64 ku, double alpha,
1036                           const DeviceMemory<double> &a, int lda,
1037                           const DeviceMemory<double> &x, int incx, double beta,
1038                           DeviceMemory<double> *y, int incy) {
1039   return DoBlasInternal(cublasDgbmv, stream, true /* = pointer_mode_host */,
1040                         CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1041                         GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1042                         GpuMemoryMutable(y), incy);
1043 }
1044 
1045 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1046                           uint64 n, uint64 kl, uint64 ku,
1047                           std::complex<float> alpha,
1048                           const DeviceMemory<std::complex<float>> &a, int lda,
1049                           const DeviceMemory<std::complex<float>> &x, int incx,
1050                           std::complex<float> beta,
1051                           DeviceMemory<std::complex<float>> *y, int incy) {
1052   auto cb_alpha = GpuComplexValue(alpha);
1053   auto cb_beta = GpuComplexValue(beta);
1054   return DoBlasInternal(cublasCgbmv, stream, true /* = pointer_mode_host */,
1055                         CUDABlasTranspose(trans), m, n, kl, ku,
1056                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1057                         GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1058                         GpuComplex(GpuMemoryMutable(y)), incy);
1059 }
1060 
1061 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1062                           uint64 n, uint64 kl, uint64 ku,
1063                           std::complex<double> alpha,
1064                           const DeviceMemory<std::complex<double>> &a, int lda,
1065                           const DeviceMemory<std::complex<double>> &x, int incx,
1066                           std::complex<double> beta,
1067                           DeviceMemory<std::complex<double>> *y, int incy) {
1068   auto cb_alpha = GpuComplexValue(alpha);
1069   auto cb_beta = GpuComplexValue(beta);
1070   return DoBlasInternal(cublasZgbmv, stream, true /* = pointer_mode_host */,
1071                         CUDABlasTranspose(trans), m, n, kl, ku,
1072                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1073                         GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1074                         GpuComplex(GpuMemoryMutable(y)), incy);
1075 }
1076 
1077 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1078                           uint64 n, float alpha, const DeviceMemory<float> &a,
1079                           int lda, const DeviceMemory<float> &x, int incx,
1080                           float beta, DeviceMemory<float> *y, int incy) {
1081   return DoBlasInternal(cublasSgemv, stream, true /* = pointer_mode_host */,
1082                         CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1083                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1084                         incy);
1085 }
1086 
1087 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1088                           uint64 n, double alpha, const DeviceMemory<double> &a,
1089                           int lda, const DeviceMemory<double> &x, int incx,
1090                           double beta, DeviceMemory<double> *y, int incy) {
1091   return DoBlasInternal(cublasDgemv, stream, true /* = pointer_mode_host */,
1092                         CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1093                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1094                         incy);
1095 }
1096 
1097 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1098                           uint64 n, std::complex<float> alpha,
1099                           const DeviceMemory<std::complex<float>> &a, int lda,
1100                           const DeviceMemory<std::complex<float>> &x, int incx,
1101                           std::complex<float> beta,
1102                           DeviceMemory<std::complex<float>> *y, int incy) {
1103   auto cb_alpha = GpuComplexValue(alpha);
1104   auto cb_beta = GpuComplexValue(beta);
1105   return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
1106                         CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1107                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1108                         incx, GpuComplex(&cb_beta),
1109                         GpuComplex(GpuMemoryMutable(y)), incy);
1110 }
1111 
1112 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1113                           uint64 n, std::complex<double> alpha,
1114                           const DeviceMemory<std::complex<double>> &a, int lda,
1115                           const DeviceMemory<std::complex<double>> &x, int incx,
1116                           std::complex<double> beta,
1117                           DeviceMemory<std::complex<double>> *y, int incy) {
1118   auto cb_alpha = GpuComplexValue(alpha);
1119   auto cb_beta = GpuComplexValue(beta);
1120   return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
1121                         CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1122                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1123                         incx, GpuComplex(&cb_beta),
1124                         GpuComplex(GpuMemoryMutable(y)), incy);
1125 }
1126 
1127 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
1128                          const DeviceMemory<float> &x, int incx,
1129                          const DeviceMemory<float> &y, int incy,
1130                          DeviceMemory<float> *a, int lda) {
1131   return DoBlasInternal(cublasSger, stream, true /* = pointer_mode_host */, m,
1132                         n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1133                         GpuMemoryMutable(a), lda);
1134 }
1135 
1136 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
1137                          const DeviceMemory<double> &x, int incx,
1138                          const DeviceMemory<double> &y, int incy,
1139                          DeviceMemory<double> *a, int lda) {
1140   return DoBlasInternal(cublasDger, stream, true /* = pointer_mode_host */, m,
1141                         n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1142                         GpuMemoryMutable(a), lda);
1143 }
1144 
1145 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1146                           std::complex<float> alpha,
1147                           const DeviceMemory<std::complex<float>> &x, int incx,
1148                           const DeviceMemory<std::complex<float>> &y, int incy,
1149                           DeviceMemory<std::complex<float>> *a, int lda) {
1150   auto cb_alpha = GpuComplexValue(alpha);
1151   return DoBlasInternal(cublasCgerc, stream, true /* = pointer_mode_host */, m,
1152                         n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1153                         incx, GpuComplex(GpuMemory(y)), incy,
1154                         GpuComplex(GpuMemoryMutable(a)), lda);
1155 }
1156 
1157 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1158                           std::complex<double> alpha,
1159                           const DeviceMemory<std::complex<double>> &x, int incx,
1160                           const DeviceMemory<std::complex<double>> &y, int incy,
1161                           DeviceMemory<std::complex<double>> *a, int lda) {
1162   auto cb_alpha = GpuComplexValue(alpha);
1163   return DoBlasInternal(cublasZgerc, stream, true /* = pointer_mode_host */, m,
1164                         n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1165                         incx, GpuComplex(GpuMemory(y)), incy,
1166                         GpuComplex(GpuMemoryMutable(a)), lda);
1167 }
1168 
1169 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1170                           std::complex<float> alpha,
1171                           const DeviceMemory<std::complex<float>> &x, int incx,
1172                           const DeviceMemory<std::complex<float>> &y, int incy,
1173                           DeviceMemory<std::complex<float>> *a, int lda) {
1174   auto cb_alpha = GpuComplexValue(alpha);
1175   return DoBlasInternal(cublasCgeru, stream, true /* = pointer_mode_host */, m,
1176                         n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1177                         incx, GpuComplex(GpuMemory(y)), incy,
1178                         GpuComplex(GpuMemoryMutable(a)), lda);
1179 }
1180 
1181 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1182                           std::complex<double> alpha,
1183                           const DeviceMemory<std::complex<double>> &x, int incx,
1184                           const DeviceMemory<std::complex<double>> &y, int incy,
1185                           DeviceMemory<std::complex<double>> *a, int lda) {
1186   auto cb_alpha = GpuComplexValue(alpha);
1187   return DoBlasInternal(cublasZgeru, stream, true /* = pointer_mode_host */, m,
1188                         n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1189                         incx, GpuComplex(GpuMemory(y)), incy,
1190                         GpuComplex(GpuMemoryMutable(a)), lda);
1191 }
1192 
1193 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1194                           uint64 k, std::complex<float> alpha,
1195                           const DeviceMemory<std::complex<float>> &a, int lda,
1196                           const DeviceMemory<std::complex<float>> &x, int incx,
1197                           std::complex<float> beta,
1198                           DeviceMemory<std::complex<float>> *y, int incy) {
1199   auto cb_alpha = GpuComplexValue(alpha);
1200   auto cb_beta = GpuComplexValue(beta);
1201   return DoBlasInternal(cublasChbmv, stream, true /* = pointer_mode_host */,
1202                         CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
1203                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1204                         incx, GpuComplex(&cb_beta),
1205                         GpuComplex(GpuMemoryMutable(y)), incy);
1206 }
1207 
1208 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1209                           uint64 k, std::complex<double> alpha,
1210                           const DeviceMemory<std::complex<double>> &a, int lda,
1211                           const DeviceMemory<std::complex<double>> &x, int incx,
1212                           std::complex<double> beta,
1213                           DeviceMemory<std::complex<double>> *y, int incy) {
1214   auto cb_alpha = GpuComplexValue(alpha);
1215   auto cb_beta = GpuComplexValue(beta);
1216   return DoBlasInternal(cublasZhbmv, stream, true /* = pointer_mode_host */,
1217                         CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
1218                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1219                         incx, GpuComplex(&cb_beta),
1220                         GpuComplex(GpuMemoryMutable(y)), incy);
1221 }
1222 
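// Note on the banded routines (Hbmv above, and Sbmv/Tbmv/Tbsv further down):
// per the standard BLAS banded-storage convention, `a` holds only the k
// super- or sub-diagonals of the band, packed into a (k + 1) x n column-major
// array, so lda is required to be >= k + 1 rather than >= n.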
1223 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1224                           std::complex<float> alpha,
1225                           const DeviceMemory<std::complex<float>> &a, int lda,
1226                           const DeviceMemory<std::complex<float>> &x, int incx,
1227                           std::complex<float> beta,
1228                           DeviceMemory<std::complex<float>> *y, int incy) {
1229   auto cb_alpha = GpuComplexValue(alpha);
1230   auto cb_beta = GpuComplexValue(beta);
1231   return DoBlasInternal(cublasChemv, stream, true /* = pointer_mode_host */,
1232                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1233                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1234                         incx, GpuComplex(&cb_beta),
1235                         GpuComplex(GpuMemoryMutable(y)), incy);
1236 }
1237 
1238 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1239                           std::complex<double> alpha,
1240                           const DeviceMemory<std::complex<double>> &a, int lda,
1241                           const DeviceMemory<std::complex<double>> &x, int incx,
1242                           std::complex<double> beta,
1243                           DeviceMemory<std::complex<double>> *y, int incy) {
1244   auto cb_alpha = GpuComplexValue(alpha);
1245   auto cb_beta = GpuComplexValue(beta);
1246   return DoBlasInternal(cublasZhemv, stream, true /* = pointer_mode_host */,
1247                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1248                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1249                         incx, GpuComplex(&cb_beta),
1250                         GpuComplex(GpuMemoryMutable(y)), incy);
1251 }
1252 
1253 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1254                          float alpha,
1255                          const DeviceMemory<std::complex<float>> &x, int incx,
1256                          DeviceMemory<std::complex<float>> *a, int lda) {
1257   return DoBlasInternal(cublasCher, stream, true /* = pointer_mode_host */,
1258                         CUDABlasUpperLower(uplo), n, &alpha,
1259                         GpuComplex(GpuMemory(x)), incx,
1260                         GpuComplex(GpuMemoryMutable(a)), lda);
1261 }
1262 
1263 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1264                          double alpha,
1265                          const DeviceMemory<std::complex<double>> &x, int incx,
1266                          DeviceMemory<std::complex<double>> *a, int lda) {
1267   return DoBlasInternal(cublasZher, stream, true /* = pointer_mode_host */,
1268                         CUDABlasUpperLower(uplo), n, &alpha,
1269                         GpuComplex(GpuMemory(x)), incx,
1270                         GpuComplex(GpuMemoryMutable(a)), lda);
1271 }
1272 
1273 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1274                           std::complex<float> alpha,
1275                           const DeviceMemory<std::complex<float>> &x, int incx,
1276                           const DeviceMemory<std::complex<float>> &y, int incy,
1277                           DeviceMemory<std::complex<float>> *a, int lda) {
1278   auto cb_alpha = GpuComplexValue(alpha);
1279   return DoBlasInternal(cublasCher2, stream, true /* = pointer_mode_host */,
1280                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1281                         GpuComplex(GpuMemory(x)), incx,
1282                         GpuComplex(GpuMemory(y)), incy,
1283                         GpuComplex(GpuMemoryMutable(a)), lda);
1284 }
1285 
1286 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1287                           std::complex<double> alpha,
1288                           const DeviceMemory<std::complex<double>> &x, int incx,
1289                           const DeviceMemory<std::complex<double>> &y, int incy,
1290                           DeviceMemory<std::complex<double>> *a, int lda) {
1291   auto cb_alpha = GpuComplexValue(alpha);
1292   return DoBlasInternal(cublasZher2, stream, true /* = pointer_mode_host */,
1293                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1294                         GpuComplex(GpuMemory(x)), incx,
1295                         GpuComplex(GpuMemory(y)), incy,
1296                         GpuComplex(GpuMemoryMutable(a)), lda);
1297 }
1298 
1299 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1300                           std::complex<float> alpha,
1301                           const DeviceMemory<std::complex<float>> &ap,
1302                           const DeviceMemory<std::complex<float>> &x, int incx,
1303                           std::complex<float> beta,
1304                           DeviceMemory<std::complex<float>> *y, int incy) {
1305   auto cb_alpha = GpuComplexValue(alpha);
1306   auto cb_beta = GpuComplexValue(beta);
1307   return DoBlasInternal(cublasChpmv, stream, true /* = pointer_mode_host */,
1308                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1309                         GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
1310                         incx, GpuComplex(&cb_beta),
1311                         GpuComplex(GpuMemoryMutable(y)), incy);
1312 }
1313 
1314 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1315                           std::complex<double> alpha,
1316                           const DeviceMemory<std::complex<double>> &ap,
1317                           const DeviceMemory<std::complex<double>> &x, int incx,
1318                           std::complex<double> beta,
1319                           DeviceMemory<std::complex<double>> *y, int incy) {
1320   auto cb_alpha = GpuComplexValue(alpha);
1321   auto cb_beta = GpuComplexValue(beta);
1322   return DoBlasInternal(cublasZhpmv, stream, true /* = pointer_mode_host */,
1323                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1324                         GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
1325                         incx, GpuComplex(&cb_beta),
1326                         GpuComplex(GpuMemoryMutable(y)), incy);
1327 }
1328 
1329 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1330                          float alpha,
1331                          const DeviceMemory<std::complex<float>> &x, int incx,
1332                          DeviceMemory<std::complex<float>> *ap) {
1333   return DoBlasInternal(cublasChpr, stream, true /* = pointer_mode_host */,
1334                         CUDABlasUpperLower(uplo), n, &alpha,
1335                         GpuComplex(GpuMemory(x)), incx,
1336                         GpuComplex(GpuMemoryMutable(ap)));
1337 }
1338 
1339 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1340                          double alpha,
1341                          const DeviceMemory<std::complex<double>> &x, int incx,
1342                          DeviceMemory<std::complex<double>> *ap) {
1343   return DoBlasInternal(cublasZhpr, stream, true /* = pointer_mode_host */,
1344                         CUDABlasUpperLower(uplo), n, &alpha,
1345                         GpuComplex(GpuMemory(x)), incx,
1346                         GpuComplex(GpuMemoryMutable(ap)));
1347 }
1348 
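// Note on the packed routines (Hpmv/Hpr above, Hpr2/Spmv/Spr/Spr2/Tpmv/Tpsv
// below): `ap` stores only the referenced triangle, packed column by column
// into n * (n + 1) / 2 consecutive elements, which is why these entry points
// take no leading-dimension argument. In the standard upper-triangular packed
// layout, element A(i, j) with 0 <= i <= j < n lives at
// ap[i + j * (j + 1) / 2].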
1349 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1350                           std::complex<float> alpha,
1351                           const DeviceMemory<std::complex<float>> &x, int incx,
1352                           const DeviceMemory<std::complex<float>> &y, int incy,
1353                           DeviceMemory<std::complex<float>> *ap) {
1354   auto cb_alpha = GpuComplexValue(alpha);
1355   return DoBlasInternal(cublasChpr2, stream, true /* = pointer_mode_host */,
1356                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1357                         GpuComplex(GpuMemory(x)), incx,
1358                         GpuComplex(GpuMemory(y)), incy,
1359                         GpuComplex(GpuMemoryMutable(ap)));
1360 }
1361 
1362 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1363                           std::complex<double> alpha,
1364                           const DeviceMemory<std::complex<double>> &x, int incx,
1365                           const DeviceMemory<std::complex<double>> &y, int incy,
1366                           DeviceMemory<std::complex<double>> *ap) {
1367   auto cb_alpha = GpuComplexValue(alpha);
1368   return DoBlasInternal(cublasZhpr2, stream, true /* = pointer_mode_host */,
1369                         CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1370                         GpuComplex(GpuMemory(x)), incx,
1371                         GpuComplex(GpuMemory(y)), incy,
1372                         GpuComplex(GpuMemoryMutable(ap)));
1373 }
1374 
1375 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1376                           uint64 k, float alpha, const DeviceMemory<float> &a,
1377                           int lda, const DeviceMemory<float> &x, int incx,
1378                           float beta, DeviceMemory<float> *y, int incy) {
1379   return DoBlasInternal(cublasSsbmv, stream, true /* = pointer_mode_host */,
1380                         CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
1381                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1382                         incy);
1383 }
1384 
1385 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1386                           uint64 k, double alpha, const DeviceMemory<double> &a,
1387                           int lda, const DeviceMemory<double> &x, int incx,
1388                           double beta, DeviceMemory<double> *y, int incy) {
1389   return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */,
1390                         CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
1391                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1392                         incy);
1393 }
1394 
1395 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1396                           float alpha, const DeviceMemory<float> &ap,
1397                           const DeviceMemory<float> &x, int incx, float beta,
1398                           DeviceMemory<float> *y, int incy) {
1399   return DoBlasInternal(cublasSspmv, stream, true /* = pointer_mode_host */,
1400                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1401                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1402 }
1403 
1404 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1405                           double alpha, const DeviceMemory<double> &ap,
1406                           const DeviceMemory<double> &x, int incx, double beta,
1407                           DeviceMemory<double> *y, int incy) {
1408   return DoBlasInternal(cublasDspmv, stream, true /* = pointer_mode_host */,
1409                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1410                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1411 }
1412 
1413 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1414                          float alpha, const DeviceMemory<float> &x, int incx,
1415                          DeviceMemory<float> *ap) {
1416   return DoBlasInternal(cublasSspr, stream, true /* = pointer_mode_host */,
1417                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1418                         GpuMemoryMutable(ap));
1419 }
1420 
1421 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1422                          double alpha, const DeviceMemory<double> &x, int incx,
1423                          DeviceMemory<double> *ap) {
1424   return DoBlasInternal(cublasDspr, stream, true /* = pointer_mode_host */,
1425                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1426                         GpuMemoryMutable(ap));
1427 }
1428 
1429 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1430                           float alpha, const DeviceMemory<float> &x, int incx,
1431                           const DeviceMemory<float> &y, int incy,
1432                           DeviceMemory<float> *ap) {
1433   return DoBlasInternal(cublasSspr2, stream, true /* = pointer_mode_host */,
1434                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1435                         GpuMemory(y), incy, GpuMemoryMutable(ap));
1436 }
1437 
1438 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1439                           double alpha, const DeviceMemory<double> &x, int incx,
1440                           const DeviceMemory<double> &y, int incy,
1441                           DeviceMemory<double> *ap) {
1442   return DoBlasInternal(cublasDspr2, stream, true /* = pointer_mode_host */,
1443                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1444                         GpuMemory(y), incy, GpuMemoryMutable(ap));
1445 }
1446 
1447 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1448                           float alpha, const DeviceMemory<float> &a, int lda,
1449                           const DeviceMemory<float> &x, int incx, float beta,
1450                           DeviceMemory<float> *y, int incy) {
1451   return DoBlasInternal(cublasSsymv, stream, true /* = pointer_mode_host */,
1452                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1453                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1454 }
1455 
1456 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1457                           double alpha, const DeviceMemory<double> &a, int lda,
1458                           const DeviceMemory<double> &x, int incx, double beta,
1459                           DeviceMemory<double> *y, int incy) {
1460   return DoBlasInternal(cublasDsymv, stream, true /* = pointer_mode_host */,
1461                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1462                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1463 }
1464 
1465 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1466                          float alpha, const DeviceMemory<float> &x, int incx,
1467                          DeviceMemory<float> *a, int lda) {
1468   return DoBlasInternal(cublasSsyr, stream, true /* = pointer_mode_host */,
1469                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1470                         GpuMemoryMutable(a), lda);
1471 }
1472 
1473 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1474                          double alpha, const DeviceMemory<double> &x, int incx,
1475                          DeviceMemory<double> *a, int lda) {
1476   return DoBlasInternal(cublasDsyr, stream, true /* = pointer_mode_host */,
1477                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1478                         GpuMemoryMutable(a), lda);
1479 }
1480 
1481 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1482                           float alpha, const DeviceMemory<float> &x, int incx,
1483                           const DeviceMemory<float> &y, int incy,
1484                           DeviceMemory<float> *a, int lda) {
1485   return DoBlasInternal(cublasSsyr2, stream, true /* = pointer_mode_host */,
1486                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1487                         GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1488 }
1489 
1490 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1491                           double alpha, const DeviceMemory<double> &x, int incx,
1492                           const DeviceMemory<double> &y, int incy,
1493                           DeviceMemory<double> *a, int lda) {
1494   return DoBlasInternal(cublasDsyr2, stream, true /* = pointer_mode_host */,
1495                         CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1496                         GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1497 }
1498 
1499 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1500                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1501                           uint64 k, const DeviceMemory<float> &a, int lda,
1502                           DeviceMemory<float> *x, int incx) {
1503   return DoBlasInternal(cublasStbmv, stream, true /* = pointer_mode_host */,
1504                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1505                         CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1506                         GpuMemoryMutable(x), incx);
1507 }
1508 
1509 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1510                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1511                           uint64 k, const DeviceMemory<double> &a, int lda,
1512                           DeviceMemory<double> *x, int incx) {
1513   return DoBlasInternal(cublasDtbmv, stream, true /* = pointer_mode_host */,
1514                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1515                         CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1516                         GpuMemoryMutable(x), incx);
1517 }
1518 
1519 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1520                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1521                           uint64 k, const DeviceMemory<std::complex<float>> &a,
1522                           int lda, DeviceMemory<std::complex<float>> *x,
1523                           int incx) {
1524   return DoBlasInternal(cublasCtbmv, stream, true /* = pointer_mode_host */,
1525                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1526                         CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1527                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1528 }
1529 
1530 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1531                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1532                           uint64 k, const DeviceMemory<std::complex<double>> &a,
1533                           int lda, DeviceMemory<std::complex<double>> *x,
1534                           int incx) {
1535   return DoBlasInternal(cublasZtbmv, stream, true /* = pointer_mode_host */,
1536                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1537                         CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1538                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1539 }
1540 
1541 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1542                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1543                           uint64 k, const DeviceMemory<float> &a, int lda,
1544                           DeviceMemory<float> *x, int incx) {
1545   return DoBlasInternal(cublasStbsv, stream, true /* = pointer_mode_host */,
1546                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1547                         CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1548                         GpuMemoryMutable(x), incx);
1549 }
1550 
1551 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1552                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1553                           uint64 k, const DeviceMemory<double> &a, int lda,
1554                           DeviceMemory<double> *x, int incx) {
1555   return DoBlasInternal(cublasDtbsv, stream, true /* = pointer_mode_host */,
1556                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1557                         CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1558                         GpuMemoryMutable(x), incx);
1559 }
1560 
1561 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1562                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1563                           uint64 k, const DeviceMemory<std::complex<float>> &a,
1564                           int lda, DeviceMemory<std::complex<float>> *x,
1565                           int incx) {
1566   return DoBlasInternal(cublasCtbsv, stream, true /* = pointer_mode_host */,
1567                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1568                         CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1569                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1570 }
1571 
1572 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1573                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1574                           uint64 k, const DeviceMemory<std::complex<double>> &a,
1575                           int lda, DeviceMemory<std::complex<double>> *x,
1576                           int incx) {
1577   return DoBlasInternal(cublasZtbsv, stream, true /* = pointer_mode_host */,
1578                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1579                         CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1580                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1581 }
1582 
1583 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1584                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1585                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1586                           int incx) {
1587   return DoBlasInternal(cublasStpmv, stream, true /* = pointer_mode_host */,
1588                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1589                         CUDABlasDiagonal(diag), n, GpuMemory(ap),
1590                         GpuMemoryMutable(x), incx);
1591 }
1592 
1593 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1594                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1595                           const DeviceMemory<double> &ap,
1596                           DeviceMemory<double> *x, int incx) {
1597   return DoBlasInternal(cublasDtpmv, stream, true /* = pointer_mode_host */,
1598                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1599                         CUDABlasDiagonal(diag), n, GpuMemory(ap),
1600                         GpuMemoryMutable(x), incx);
1601 }
1602 
1603 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1604                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1605                           const DeviceMemory<std::complex<float>> &ap,
1606                           DeviceMemory<std::complex<float>> *x, int incx) {
1607   return DoBlasInternal(cublasCtpmv, stream, true /* = pointer_mode_host */,
1608                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1609                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1610                         GpuComplex(GpuMemoryMutable(x)), incx);
1611 }
1612 
1613 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1614                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1615                           const DeviceMemory<std::complex<double>> &ap,
1616                           DeviceMemory<std::complex<double>> *x, int incx) {
1617   return DoBlasInternal(cublasZtpmv, stream, true /* = pointer_mode_host */,
1618                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1619                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1620                         GpuComplex(GpuMemoryMutable(x)), incx);
1621 }
1622 
1623 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1624                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1625                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1626                           int incx) {
1627   return DoBlasInternal(cublasStpsv, stream, true /* = pointer_mode_host */,
1628                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1629                         CUDABlasDiagonal(diag), n, GpuMemory(ap),
1630                         GpuMemoryMutable(x), incx);
1631 }
1632 
1633 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1634                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1635                           const DeviceMemory<double> &ap,
1636                           DeviceMemory<double> *x, int incx) {
1637   return DoBlasInternal(cublasDtpsv, stream, true /* = pointer_mode_host */,
1638                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1639                         CUDABlasDiagonal(diag), n, GpuMemory(ap),
1640                         GpuMemoryMutable(x), incx);
1641 }
1642 
1643 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1644                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1645                           const DeviceMemory<std::complex<float>> &ap,
1646                           DeviceMemory<std::complex<float>> *x, int incx) {
1647   return DoBlasInternal(cublasCtpsv, stream, true /* = pointer_mode_host */,
1648                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1649                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1650                         GpuComplex(GpuMemoryMutable(x)), incx);
1651 }
1652 
1653 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1654                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1655                           const DeviceMemory<std::complex<double>> &ap,
1656                           DeviceMemory<std::complex<double>> *x, int incx) {
1657   return DoBlasInternal(cublasZtpsv, stream, true /* = pointer_mode_host */,
1658                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1659                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1660                         GpuComplex(GpuMemoryMutable(x)), incx);
1661 }
1662 
1663 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1664                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1665                           const DeviceMemory<float> &a, int lda,
1666                           DeviceMemory<float> *x, int incx) {
1667   return DoBlasInternal(cublasStrmv, stream, true /* = pointer_mode_host */,
1668                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1669                         CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1670                         GpuMemoryMutable(x), incx);
1671 }
1672 
1673 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1674                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1675                           const DeviceMemory<double> &a, int lda,
1676                           DeviceMemory<double> *x, int incx) {
1677   return DoBlasInternal(cublasDtrmv, stream, true /* = pointer_mode_host */,
1678                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1679                         CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1680                         GpuMemoryMutable(x), incx);
1681 }
1682 
1683 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1684                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1685                           const DeviceMemory<std::complex<float>> &a, int lda,
1686                           DeviceMemory<std::complex<float>> *x, int incx) {
1687   return DoBlasInternal(cublasCtrmv, stream, true /* = pointer_mode_host */,
1688                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1689                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1690                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1691 }
1692 
1693 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1694                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1695                           const DeviceMemory<std::complex<double>> &a, int lda,
1696                           DeviceMemory<std::complex<double>> *x, int incx) {
1697   return DoBlasInternal(cublasZtrmv, stream, true /* = pointer_mode_host */,
1698                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1699                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1700                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1701 }
1702 
1703 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1704                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1705                           const DeviceMemory<float> &a, int lda,
1706                           DeviceMemory<float> *x, int incx) {
1707   return DoBlasInternal(cublasStrsv, stream, true /* = pointer_mode_host */,
1708                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1709                         CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1710                         GpuMemoryMutable(x), incx);
1711 }
1712 
1713 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1714                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1715                           const DeviceMemory<double> &a, int lda,
1716                           DeviceMemory<double> *x, int incx) {
1717   return DoBlasInternal(cublasDtrsv, stream, true /* = pointer_mode_host */,
1718                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1719                         CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1720                         GpuMemoryMutable(x), incx);
1721 }
1722 
1723 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1724                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1725                           const DeviceMemory<std::complex<float>> &a, int lda,
1726                           DeviceMemory<std::complex<float>> *x, int incx) {
1727   return DoBlasInternal(cublasCtrsv, stream, true /* = pointer_mode_host */,
1728                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1729                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1730                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1731 }
1732 
1733 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1734                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1735                           const DeviceMemory<std::complex<double>> &a, int lda,
1736                           DeviceMemory<std::complex<double>> *x, int incx) {
1737   return DoBlasInternal(cublasZtrsv, stream, true /* = pointer_mode_host */,
1738                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1739                         CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1740                         lda, GpuComplex(GpuMemoryMutable(x)), incx);
1741 }
1742 
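// The triangular multiply/solve wrappers above (Trmv/Trsv and their banded and
// packed variants Tbmv/Tbsv/Tpmv/Tpsv) update `x` in place, following the
// usual BLAS semantics: after Trmv, x holds op(A) * x; after Trsv, x holds the
// solution of op(A) * x = b, where b is the vector originally passed in x and
// op(A) is A, A^T or A^H depending on `trans`.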
1743 port::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1744                                   blas::Transpose transb, uint64 m, uint64 n,
1745                                   uint64 k, blas::DataType dtype,
1746                                   const void *alpha, const DeviceMemoryBase &a,
1747                                   int lda, const DeviceMemoryBase &b, int ldb,
1748                                   const void *beta, DeviceMemoryBase *c,
1749                                   int ldc) {
1750   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
1751 
1752 #if CUDA_VERSION < 11000
1753   if (dtype == blas::DataType::kHalf) {
1754     math_type = CUBLAS_TENSOR_OP_MATH;
1755   }
1756 #else
1757   if (dtype == blas::DataType::kFloat) {
1758     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
1759     if (stream->GetCudaComputeCapability().IsAtLeast(
1760             CudaComputeCapability::AMPERE)) {
1761       // TODO(reedwm): Remove or make this VLOG(1) once TensorFloat-32 is
1762       // better tested.
1763       if (tensorflow::tensor_float_32_execution_enabled()) {
1764         LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
1765                                 "multiplication. This will only be logged "
1766                                 "once.";
1767       }
1768     }
1769   }
1770 #endif
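  // Note: TensorFloat-32 use for fp32 GEMMs is controlled process-wide by the
  // flag read via tensorflow::tensor_float_32_execution_enabled(); the
  // companion setter declared in the same header turns it off, e.g.
  // (illustrative usage only):
  //
  //   tensorflow::enable_tensor_float_32_execution(false);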
1771 
1772   // TODO(cheshire): Return an error instead.
1773   // TODO(cheshire): Why are these checked only for `half` and `float`?
1774   if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) {
1775     if (transa == blas::Transpose::kNoTranspose) {
1776       if (lda < static_cast<int64>(m)) {
1777         LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than m (" << m
1778                      << ") (no transpose case); precondition violation";
1779       }
1780     } else {
1781       if (lda < static_cast<int64>(k)) {
1782         LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1783                      << ") (transpose case); precondition violation";
1784       }
1785     }
1786     if (transb == blas::Transpose::kNoTranspose) {
1787       if (ldb < static_cast<int64>(k)) {
1788         LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1789                      << ") (no transpose case); precondition violation";
1790       }
1791     } else {
1792       if (ldb < static_cast<int64>(n)) {
1793         LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than n (" << n
1794                      << ") (transpose case); precondition violation";
1795       }
1796     }
1797   }
1798 
1799   VLOG(1) << absl::StrFormat(
1800       "doing cuBLAS GEMM: at=%d bt=%d m=%u n=%u "
1801       "k=%u alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
1802       "c=%p ldc=%d",
1803       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1804       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1805 
1806   switch (dtype) {
1807     case blas::DataType::kHalf: {
1808 #if CUDA_VERSION < 7050
1809       return port::InternalError(
1810           "fp16 sgemm is not implemented in this cuBLAS version "
1811           "(need at least CUDA 7.5)");
1812 #endif
1813 
1814       return DoBlasInternalImpl(
1815           cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
1816           CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1817           static_cast<const float *>(alpha), a.opaque(), SE_CUDA_DATA_HALF, lda,
1818           b.opaque(), SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
1819           c->opaque(), SE_CUDA_DATA_HALF, ldc);
1820     }
1821 #if CUDA_VERSION > 11000
1822     case blas::DataType::kBF16: {
1823       return DoBlasInternalImpl(
1824           cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
1825           CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1826           static_cast<const float *>(alpha), a.opaque(), CUDA_R_16BF, lda,
1827           b.opaque(), CUDA_R_16BF, ldb, static_cast<const float *>(beta),
1828           c->opaque(), CUDA_R_16BF, ldc);
1829     }
1830 #endif
1831     case blas::DataType::kFloat:
1832       return DoBlasInternalImpl(
1833           cublasSgemm, stream, true /* = pointer_mode_host */, math_type,
1834           CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1835           static_cast<const float *>(alpha),
1836           static_cast<const float *>(a.opaque()), lda,
1837           static_cast<const float *>(b.opaque()), ldb,
1838           static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
1839           ldc);
1840     case blas::DataType::kDouble:
1841       return DoBlasInternalImpl(
1842           cublasDgemm, stream, true /* = pointer_mode_host */, math_type,
1843           CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1844           static_cast<const double *>(alpha),
1845           static_cast<const double *>(a.opaque()), lda,
1846           static_cast<const double *>(b.opaque()), ldb,
1847           static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
1848           ldc);
1849     case blas::DataType::kComplexFloat: {
1850       GpuComplexType cb_alpha =
1851           GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
1852       GpuComplexType cb_beta =
1853           GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
1854       return DoBlasInternalImpl(
1855           cublasCgemm, stream, true /* = pointer_mode_host */, math_type,
1856           CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1857           &cb_alpha, static_cast<const GpuComplexType *>(a.opaque()), lda,
1858           static_cast<const GpuComplexType *>(b.opaque()), ldb, &cb_beta,
1859           static_cast<GpuComplexType *>(c->opaque()), ldc);
1860     }
1861     case blas::DataType::kComplexDouble: {
1862       GpuDoubleComplexType cb_alpha =
1863           GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
1864       GpuDoubleComplexType cb_beta =
1865           GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
1866       return DoBlasInternalImpl(
1867           cublasZgemm, stream, true /* = pointer_mode_host */, math_type,
1868           CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1869           &cb_alpha, static_cast<const GpuDoubleComplexType *>(a.opaque()), lda,
1870           static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, &cb_beta,
1871           static_cast<GpuDoubleComplexType *>(c->opaque()), ldc);
1872     }
1873     default:
1874       return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
1875                                               blas::DataTypeString(dtype)));
1876   }
1877 }
1878 
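// A minimal fp32 call through the DoBlasGemm entry point above, for
// illustration only (the stream, device buffers and sizes are hypothetical
// placeholders; production code normally reaches this via the Stream BLAS
// helpers such as Stream::ThenBlasGemm):
//
//   float alpha = 1.0f, beta = 0.0f;
//   port::Status st = blas->DoBlasGemm(
//       &stream, blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose,
//       m, n, k, blas::DataType::kFloat, &alpha, a_mem, /*lda=*/m, b_mem,
//       /*ldb=*/k, &beta, &c_mem, /*ldc=*/m);
//
// As with cuBLAS itself, matrices are column-major and lda/ldb/ldc are their
// leading dimensions.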
1879 bool CUDABlas::DoBlasGemvWithProfiling(
1880     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1881     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1882     int incx, float beta, DeviceMemory<float> *y, int incy,
1883     blas::ProfileResult *output_profile_result) {
1884   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1885                                      incx, beta, y, incy,
1886                                      output_profile_result);
1887 }
1888 
1889 bool CUDABlas::DoBlasGemvWithProfiling(
1890     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1891     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1892     int incx, double beta, DeviceMemory<double> *y, int incy,
1893     blas::ProfileResult *output_profile_result) {
1894   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1895                                      incx, beta, y, incy,
1896                                      output_profile_result);
1897 }
1898 
1899 bool CUDABlas::DoBlasGemvWithProfiling(
1900     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1901     std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1902     int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1903     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1904     blas::ProfileResult *output_profile_result) {
1905   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1906                                      incx, beta, y, incy,
1907                                      output_profile_result);
1908 }
1909 
1910 bool CUDABlas::DoBlasGemvWithProfiling(
1911     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1912     std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1913     int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1914     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1915     blas::ProfileResult *output_profile_result) {
1916   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1917                                      incx, beta, y, incy,
1918                                      output_profile_result);
1919 }
1920 
1921 bool CUDABlas::DoBlasGemmWithProfiling(
1922     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1923     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1924     int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1925     DeviceMemory<Eigen::half> *c, int ldc,
1926     blas::ProfileResult *output_profile_result) {
1927   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1928                                      lda, b, ldb, beta, c, ldc,
1929                                      output_profile_result);
1930 }
1931 
1932 bool CUDABlas::DoBlasGemmWithProfiling(
1933     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1934     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1935     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1936     int ldc, blas::ProfileResult *output_profile_result) {
1937   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1938                                      lda, b, ldb, beta, c, ldc,
1939                                      output_profile_result);
1940 }
1941 
1942 bool CUDABlas::DoBlasGemmWithProfiling(
1943     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1944     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1945     const DeviceMemory<double> &b, int ldb, double beta,
1946     DeviceMemory<double> *c, int ldc,
1947     blas::ProfileResult *output_profile_result) {
1948   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1949                                      lda, b, ldb, beta, c, ldc,
1950                                      output_profile_result);
1951 }
1952 
1953 bool CUDABlas::DoBlasGemmWithProfiling(
1954     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1955     uint64 n, uint64 k, std::complex<float> alpha,
1956     const DeviceMemory<std::complex<float>> &a, int lda,
1957     const DeviceMemory<std::complex<float>> &b, int ldb,
1958     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1959     blas::ProfileResult *output_profile_result) {
1960   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1961                                      lda, b, ldb, beta, c, ldc,
1962                                      output_profile_result);
1963 }
1964 
1965 bool CUDABlas::DoBlasGemmWithProfiling(
1966     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1967     uint64 n, uint64 k, std::complex<double> alpha,
1968     const DeviceMemory<std::complex<double>> &a, int lda,
1969     const DeviceMemory<std::complex<double>> &b, int ldb,
1970     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1971     blas::ProfileResult *output_profile_result) {
1972   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1973                                      lda, b, ldb, beta, c, ldc,
1974                                      output_profile_result);
1975 }
1976 
1977 template <typename T>
1978 bool CUDABlas::DoBlasGemvWithProfilingImpl(
1979     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1980     const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1981     const T &beta, DeviceMemory<T> *y, int incy,
1982     blas::ProfileResult *output_profile_result) {
1983   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1984   if (output_profile_result != nullptr) {
1985     timer.reset(new GpuTimer(parent_));
1986     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1987       return false;
1988     }
1989   }
1990 
1991   // Call blasGemv
1992   bool result =
1993       DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
1994 
1995   if (timer != nullptr && result) {
1996     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
1997     // state.
1998     if (!timer->Stop(AsGpuStream(stream))) {
1999       return false;
2000     }
2001     output_profile_result->set_is_valid(true);
2002     output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
2003     output_profile_result->set_elapsed_time_in_ms(
2004         timer->GetElapsedMilliseconds());
2005   }
2006   return result;
2007 }
2008 
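// Both profiling helpers (the GEMV one above and the GEMM one below) share the
// same structure: optionally start a GpuTimer on the stream, run the regular
// DoBlasGemv / DoBlasGemm entry point, then stop the timer and record the
// algorithm and elapsed milliseconds into the ProfileResult. Stop() is only
// attempted when the BLAS call succeeded, because GpuTimer CHECK-fails if
// stopped while the stream is in an error state.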
2009 template <typename T, typename ParamType>
2010 bool CUDABlas::DoBlasGemmWithProfilingImpl(
2011     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2012     uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
2013     int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
2014     DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
2015   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2016   if (output_profile_result != nullptr) {
2017     timer.reset(new GpuTimer(parent_));
2018     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2019       return false;
2020     }
2021   }
2022 
2023   // Call blasGemm
2024   bool result =
2025       DoBlasGemm(stream, transa, transb, m, n, k, blas::ToDataType<T>::value,
2026                  &alpha, a, lda, b, ldb, &beta, c, ldc)
2027           .ok();
2028 
2029   if (timer != nullptr && result) {
2030     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2031     // state.
2032     if (!timer->Stop(AsGpuStream(stream))) {
2033       return false;
2034     }
2035     output_profile_result->set_is_valid(true);
2036     output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
2037     output_profile_result->set_elapsed_time_in_ms(
2038         timer->GetElapsedMilliseconds());
2039   }
2040   return result;
2041 }
2042 
2043 static bool UsesTensorOps(blas::AlgorithmType algo) {
2044 #if CUDA_VERSION >= 9000
2045   cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
2046   return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
2047 #else
2048   return false;
2049 #endif
2050 }
2051 
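// The single comparison in UsesTensorOps() is sufficient because, in the
// CUDA >= 9 cublasGemmAlgo_t enumeration, every *_TENSOR_OP algorithm value is
// numerically greater than or equal to CUBLAS_GEMM_DEFAULT_TENSOR_OP, while
// all non-tensor-op algorithms are smaller.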
2052 static port::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
2053     Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
2054     blas::DataType type_b) {
2055   if (type_a != type_b) {
2056     return port::InternalError("Types of inputs mismatch");
2057   }
2058 
2059   // GPUs < sm_50 don't support cublasGemmEx.
2060   CudaComputeCapability cc = stream->GetCudaComputeCapability();
2061   if (cc.major < 5) {
2062     return port::InternalError(absl::StrCat(
2063         "sm_", cc.major, " does not support explicit gemm algorithms."));
2064   }
2065 
2066   bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
2067   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2068   if (algo_uses_tensor_ops) {
2069     if (cc.major < 7) {
2070       return port::InternalError(absl::StrCat(
2071           "Algorithm ", algorithm,
2072           " uses tensor ops, but tensor ops are not available in sm", cc.major,
2073           "X devices."));
2074     } else if (type_a == blas::DataType::kFloat) {
2075 #if CUDA_VERSION < 11000
2076       return port::InternalError(absl::StrCat(
2077           "Algorithm ", algorithm,
2078           " uses tensor ops, but tensor ops are not available for fp32"));
2079 #else
2080       if (cc.major < 8) {
2081         return port::InternalError(absl::StrCat(
2082             "Algorithm ", algorithm,
2083             " uses tensor ops, but tensor ops are not available in sm",
2084             cc.major, "X devices for float input types."));
2085       } else if (!tensorflow::tensor_float_32_execution_enabled()) {
2086         return port::InternalError(absl::StrCat(
2087             "Algorithm ", algorithm,
2088             " uses tensor ops, but tensor ops are disabled for fp32 inputs"));
2089       }
2090       math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2091 #endif
2092     } else if (type_a == blas::DataType::kHalf) {
2093 #if CUDA_VERSION < 11000
2094       math_type = CUBLAS_TENSOR_OP_MATH;
2095 #endif
2096     } else {
2097       return port::InternalError(
2098           absl::StrCat("Algorithm ", algorithm,
2099                        " uses tensor ops which are not supported for input"));
2100     }
2101   }
2102 
2103   // Return an error if we might be hitting a cuBLAS bug that produces the
2104   // wrong result. See nvbugs/2156201, b/79126339.
2105 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
2106   if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
2107       std::max({m, n, k}) >= 2097153 && cc.major < 7) {
2108     return port::InternalError(
2109         "DoBlasGemmWithAlgorithm returning an error to work around cuBLAS "
2110         "< 9.2 bug with m, n, or k >= 2097153.  See b/79126339.");
2111   }
2112 #endif
2113   return math_type;
2114 }
2115 
2116 static port::StatusOr<std::unique_ptr<GpuTimer, GpuTimerDeleter>>
2117 StartGpuTimerForProfile(Stream *stream, GpuExecutor *executor,
2118                         blas::ProfileResult *output_profile_result) {
2119   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2120   if (output_profile_result) {
2121     timer.reset(new GpuTimer(executor));
2122     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2123       return port::InternalError(
2124           "output_profile_result given, but unable to create a GpuTimer");
2125     }
2126   }
2127   return timer;
2128 }
2129 
2130 static port::Status PopulateProfileFromTimer(
2131     GpuTimer *timer, blas::AlgorithmType algorithm,
2132     blas::ProfileResult *output_profile_result, Stream *stream) {
2133   if (timer) {
2134     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2135     // state.
2136     if (!timer->Stop(AsGpuStream(stream))) {
2137       return port::InternalError("unable to stop GpuTimer.");
2138     }
2139     output_profile_result->set_is_valid(true);
2140     output_profile_result->set_algorithm(algorithm);
2141     output_profile_result->set_elapsed_time_in_ms(
2142         timer->GetElapsedMilliseconds());
2143   }
2144   return port::Status::OK();
2145 }
2146 
2147 port::Status CUDABlas::DoBlasGemmWithAlgorithm(
2148     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2149     uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
2150     blas::DataType type_a, int lda, const DeviceMemoryBase &b,
2151     blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
2152     blas::DataType type_c, int ldc, blas::ComputationType computation_type,
2153     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2154   TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
2155                       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
2156 
2157   TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
2158                                       stream, parent_, output_profile_result));
2159 
2160   // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
2161   // we do the following compile-time check on the default value:
2162   static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
2163 
2164   TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2165       AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
2166       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, alpha,
2167       a.opaque(), GetCUDADataType(type_a), lda, b.opaque(),
2168       GetCUDADataType(type_b), ldb, beta, c->opaque(), GetCUDADataType(type_c),
2169       ldc, CUDAComputationType(computation_type),
2170       static_cast<cublasGemmAlgo_t>(algorithm)));
2171   TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
2172                                               output_profile_result, stream));
2173   return port::Status::OK();
2174 }
2175 
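// Illustrative autotuning sketch for DoBlasGemmWithAlgorithm above (the caller
// code and its variables are hypothetical): enumerate the candidate algorithms
// reported by GetBlasGemmAlgorithms() further below, time each one via the
// ProfileResult out-parameter, and keep the fastest.
//
//   std::vector<blas::AlgorithmType> algos;
//   blas->GetBlasGemmAlgorithms(&algos);
//   blas::AlgorithmType best = blas::kDefaultGemmAlgo;
//   float best_ms = std::numeric_limits<float>::infinity();
//   for (blas::AlgorithmType algo : algos) {
//     blas::ProfileResult profile;
//     port::Status st = blas->DoBlasGemmWithAlgorithm(
//         &stream, transa, transb, m, n, k, &alpha, a_mem, type_a, lda, b_mem,
//         type_b, ldb, &beta, &c_mem, type_c, ldc, computation_type, algo,
//         &profile);
//     if (st.ok() && profile.is_valid() &&
//         profile.elapsed_time_in_ms() < best_ms) {
//       best_ms = profile.elapsed_time_in_ms();
//       best = algo;
//     }
//   }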
2176 port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
2177     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2178     uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
2179     blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
2180     blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
2181     DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
2182     int batch_count, blas::ComputationType computation_type,
2183     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2184   TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
2185                       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
2186   TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
2187                                       stream, parent_, output_profile_result));
2188 
2189   cudaDataType_t cuda_in_type = GetCUDADataType(type_a);
2190 
2191 #if CUDA_VERSION >= 11000
2192   // Work around a CUDA bug where bf16 batched GEMM is erroneously marked
2193   // as unsupported, by manually unbatching it on pre-Volta (e.g. Pascal) GPUs.
2194   if (cuda_in_type == CUDA_R_16BF &&
2195       !stream->GetCudaComputeCapability().IsAtLeast(7)) {
2196     for (int batch = 0; batch < batch_count; ++batch) {
2197       const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2198           static_cast<const Eigen::bfloat16 *>(a.opaque()) + batch * stride_a);
2199       const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2200           static_cast<const Eigen::bfloat16 *>(b.opaque()) + batch * stride_b);
2201       auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
2202           static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
2203       TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2204           AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true,
2205           math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2206           k, static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
2207           b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
2208           c_matrix, CUDA_R_16BF, ldc, CUDAComputationType(computation_type),
2209           static_cast<cublasGemmAlgo_t>(algorithm)));
2210     }
2211     TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
2212                                                 output_profile_result, stream));
2213     return port::Status::OK();
2214   }
2215 #endif
2216 
2217   TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2218       AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,
2219       math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2220       alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type,
2221       ldb, stride_b, beta, c->opaque(), GetCUDADataType(type_c), ldc, stride_c,
2222       batch_count, CUDAComputationType(computation_type),
2223       static_cast<cublasGemmAlgo_t>(algorithm)));
2224   TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
2225                                               output_profile_result, stream));
2226   return port::Status::OK();
2227 }
2228 
2229 bool CUDABlas::GetBlasGemmAlgorithms(
2230     std::vector<blas::AlgorithmType> *out_algorithms) {
2231   // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
2232   // were first introduced in CUDA 8.
2233   //
2234   // Note that even when the CUDA version or compute capability is not
2235   // sufficient for some of these algorithms, out_algorithms is still
2236   // populated; the caller is responsible for not using unsupported entries.
2237   *out_algorithms = {
2238     CUBLAS_GEMM_DFALT,
2239     CUBLAS_GEMM_ALGO0,
2240     CUBLAS_GEMM_ALGO1,
2241     CUBLAS_GEMM_ALGO2,
2242     CUBLAS_GEMM_ALGO3,
2243     CUBLAS_GEMM_ALGO4,
2244     CUBLAS_GEMM_ALGO5,
2245     CUBLAS_GEMM_ALGO6,
2246     CUBLAS_GEMM_ALGO7,
2247 #if CUDA_VERSION >= 9000
2248     CUBLAS_GEMM_ALGO8,
2249     CUBLAS_GEMM_ALGO9,
2250     CUBLAS_GEMM_ALGO10,
2251     CUBLAS_GEMM_ALGO11,
2252     CUBLAS_GEMM_ALGO12,
2253     CUBLAS_GEMM_ALGO13,
2254     CUBLAS_GEMM_ALGO14,
2255     CUBLAS_GEMM_ALGO15,
2256     CUBLAS_GEMM_ALGO16,
2257     CUBLAS_GEMM_ALGO17,
2258     CUBLAS_GEMM_DFALT_TENSOR_OP,
2259     CUBLAS_GEMM_ALGO0_TENSOR_OP,
2260     CUBLAS_GEMM_ALGO1_TENSOR_OP,
2261     CUBLAS_GEMM_ALGO2_TENSOR_OP,
2262     CUBLAS_GEMM_ALGO3_TENSOR_OP,
2263     CUBLAS_GEMM_ALGO4_TENSOR_OP,
2264 #endif
2265 #if CUDA_VERSION >= 9020
2266     CUBLAS_GEMM_ALGO18,
2267     CUBLAS_GEMM_ALGO19,
2268     CUBLAS_GEMM_ALGO20,
2269     CUBLAS_GEMM_ALGO21,
2270     CUBLAS_GEMM_ALGO22,
2271     CUBLAS_GEMM_ALGO23,
2272     CUBLAS_GEMM_ALGO5_TENSOR_OP,
2273     CUBLAS_GEMM_ALGO6_TENSOR_OP,
2274     CUBLAS_GEMM_ALGO7_TENSOR_OP,
2275     CUBLAS_GEMM_ALGO8_TENSOR_OP,
2276     CUBLAS_GEMM_ALGO9_TENSOR_OP,
2277     CUBLAS_GEMM_ALGO10_TENSOR_OP,
2278     CUBLAS_GEMM_ALGO11_TENSOR_OP,
2279     CUBLAS_GEMM_ALGO12_TENSOR_OP,
2280     CUBLAS_GEMM_ALGO13_TENSOR_OP,
2281     CUBLAS_GEMM_ALGO14_TENSOR_OP,
2282     CUBLAS_GEMM_ALGO15_TENSOR_OP,
2283 #endif
2284   };
2285   return true;
2286 }
2287 
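// HalfAsFloat<Eigen::half> yields float: fp16 batched GEMMs below are routed
// either through cublasGemmBatchedEx (which takes void* pointer arrays) or
// through an unbatched fallback loop, so the element type of the device-side
// pointer arrays only needs to be a stand-in for fp16.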
2288 template <typename T>
2289 struct HalfAsFloat {
2290   typedef T type;
2291 };
2292 
2293 template <>
2294 struct HalfAsFloat<Eigen::half> {
2295   typedef float type;
2296 };
2297 
2298 namespace {
2299 // Pass-through for non-complex types that don't need conversion to a
2300 // cuBLAS-specific type.
2301 template <typename T>
2302 T inline GpuComplexValue(T v) {
2303   return v;
2304 }
2305 }  // namespace
2306 
2307 template <typename T, typename Scalar, typename FuncT>
2308 port::Status CUDABlas::DoBlasGemmBatchedInternal(
2309     FuncT cublas_func, Stream *stream, blas::Transpose transa,
2310     blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
2311     const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
2312     const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
2313     Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
2314     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2315   std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
2316   for (int i = 0; i < batch_count; ++i) {
2317     a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
2318     b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
2319     c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
2320   }
2321 
2322   typedef typename HalfAsFloat<typename GpuComplexT<T>::type>::type CUDA_T;
2323 
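  // Size in bytes of one array of batch_count device-side matrix pointers.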
2324   const size_t size = batch_count * sizeof(CUDA_T *);
2325 
2326   // Device-side copy of pointers to matrices.
2327   DeviceMemory<CUDA_T *> a;
2328   DeviceMemory<CUDA_T *> b;
2329   DeviceMemory<CUDA_T *> c;
2330 
2331   // If temporary space is allocated for device-side copies of pointers to
2332   // matrices, that temporary space should not be freed until this function
2333   // returns. Although the values for these unique_ptrs are not set here, they
2334   // are declared at this scope so they will be destroyed when the function
2335   // returns.
2336   //
2337   // If a scratch allocator is provided, these pointers will not be used at all.
2338   std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
2339   std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
2340   std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;
2341 
2342   // Decide how to allocate device-side copy of pointers to matrices based on
2343   // whether a scratch allocator was passed.
2344   if (scratch_allocator != nullptr) {
2345     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
2346                         scratch_allocator->AllocateBytes(size));
2347     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
2348                         scratch_allocator->AllocateBytes(size));
2349     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
2350                         scratch_allocator->AllocateBytes(size));
2351     a = DeviceMemory<CUDA_T *>(a_bytes);
2352     b = DeviceMemory<CUDA_T *>(b_bytes);
2353     c = DeviceMemory<CUDA_T *>(c_bytes);
2354   } else {
2355     SE_ASSIGN_OR_RETURN(a_temporary,
2356                         stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2357     SE_ASSIGN_OR_RETURN(b_temporary,
2358                         stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2359     SE_ASSIGN_OR_RETURN(c_temporary,
2360                         stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2361     a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
2362     b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
2363     c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
2364   }
2365 
2366   if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
2367       !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
2368       !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
2369     return port::Status(port::error::INTERNAL,
2370                         "failed to copy memory from host to device in "
2371                         "CUDABlas::DoBlasGemmBatched");
2372   }
2373 
2374   cudaDataType_t data_type = CUDADataType<T>::type;
2375 
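  // cublasGemmBatchedEx is only available from CUDA 9.1 and is only used here
  // on GPUs with compute capability >= 5.0; older configurations fall through
  // to the typed cublas<T>gemmBatched call or, for fp16, to a per-batch loop.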
2376 #if CUDA_VERSION >= 9010
2377   if (stream->GetCudaComputeCapability().IsAtLeast(5)) {
2378     cublasMath_t math_type;
2379     cublasGemmAlgo_t algo;
2380     if (data_type == CUDA_R_16F) {
2381 #if CUDA_VERSION < 11000
2382       math_type = CUBLAS_TENSOR_OP_MATH;
2383 #else
2384       math_type = CUBLAS_DEFAULT_MATH;
2385 #endif
2386       algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
2387 #if CUBLAS_VER_MAJOR >= 11
2388     } else if (data_type == CUDA_R_32F) {
2389       // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
2390       // if TensorFloat-32 is disabled.
2391       math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2392       algo = tensorflow::tensor_float_32_execution_enabled()
2393                  ? CUBLAS_GEMM_DFALT_TENSOR_OP
2394                  : CUBLAS_GEMM_DFALT;
2395 #endif
2396     } else {
2397       math_type = CUBLAS_DEFAULT_MATH;
2398       algo = CUBLAS_GEMM_DFALT;
2399     }
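    // fp16 GEMMs accumulate in fp32; every other data type accumulates in its
    // own precision.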
2400     cudaDataType_t compute_type =
2401         (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
2402     const void **a_void_ptrs = reinterpret_cast<const void **>(
2403         const_cast<const CUDA_T **>(GpuMemory(a)));
2404     const void **b_void_ptrs = reinterpret_cast<const void **>(
2405         const_cast<const CUDA_T **>(GpuMemory(b)));
2406     void **c_void_ptrs =
2407         reinterpret_cast<void **>(const_cast<CUDA_T **>(GpuMemory(c)));
2408     return DoBlasInternalImpl(
2409         AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
2410         math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2411         k, &alpha, a_void_ptrs, data_type, lda, b_void_ptrs, data_type, ldb,
2412         &beta, c_void_ptrs, data_type, ldc, batch_count, compute_type, algo);
2413   }
2414 #endif
2415   // Either CUDA_VERSION < 9.1 or SM < 5.0.
2416   if (data_type != CUDA_R_16F) {
2417     auto cb_alpha = GpuComplexValue(alpha);
2418     auto cb_beta = GpuComplexValue(beta);
2419     bool ok = DoBlasInternal(
2420         cublas_func, stream, true /* = pointer_mode_host */,
2421         CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2422         GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
2423         const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2424         const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
2425     if (ok) {
2426       return port::Status::OK();
2427     }
2428     return port::Status(port::error::INTERNAL,
2429                         "failed BLAS call, see log for details");
2430   } else {
2431     // Fall back to a loop of per-batch GEMM calls for fp16.
2432     for (int b = 0; b < batch_count; ++b) {
2433       const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
2434       const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
2435       DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
2436       TF_RETURN_IF_ERROR(DoBlasGemm(
2437           stream, transa, transb, m, n, k, blas::ToDataType<T>::value, &alpha,
2438           a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc));
2439     }
2440     return port::Status::OK();
2441   }
2442 }
2443 
2444 bool CUDABlas::DoBlasGemmBatched(
2445     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2446     uint64 n, uint64 k, float alpha,
2447     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
2448     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
2449     float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
2450     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2451   // Note: The func passed here (cublasSgemmBatched) is not actually called,
2452   // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
2453   port::Status status = DoBlasGemmBatchedInternal(
2454       cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2455       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2456   if (!status.ok()) {
2457     LOG(ERROR) << status;
2458   }
2459   return status.ok();
2460 }
2461 
2462 bool CUDABlas::DoBlasGemmBatched(
2463     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2464     uint64 n, uint64 k, float alpha,
2465     const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
2466     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
2467     const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
2468     int batch_count, ScratchAllocator *scratch_allocator) {
2469   port::Status status = DoBlasGemmBatchedInternal(
2470       cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2471       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2472   if (!status.ok()) {
2473     LOG(ERROR) << status;
2474   }
2475   return status.ok();
2476 }
2477 
2478 bool CUDABlas::DoBlasGemmBatched(
2479     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2480     uint64 n, uint64 k, double alpha,
2481     const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
2482     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
2483     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
2484     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2485   port::Status status = DoBlasGemmBatchedInternal(
2486       cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2487       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2488   if (!status.ok()) {
2489     LOG(ERROR) << status;
2490   }
2491   return status.ok();
2492 }
2493 
2494 bool CUDABlas::DoBlasGemmBatched(
2495     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2496     uint64 n, uint64 k, std::complex<float> alpha,
2497     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
2498     int lda,
2499     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
2500     int ldb, std::complex<float> beta,
2501     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
2502     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2503   port::Status status = DoBlasGemmBatchedInternal(
2504       cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2505       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2506   if (!status.ok()) {
2507     LOG(ERROR) << status;
2508   }
2509   return status.ok();
2510 }
2511 
2512 bool CUDABlas::DoBlasGemmBatched(
2513     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2514     uint64 n, uint64 k, std::complex<double> alpha,
2515     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
2516     int lda,
2517     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
2518     int ldb, std::complex<double> beta,
2519     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
2520     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2521   port::Status status = DoBlasGemmBatchedInternal(
2522       cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2523       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2524   if (!status.ok()) {
2525     LOG(ERROR) << status;
2526   }
2527   return status.ok();
2528 }
2529 
2530 port::Status CUDABlas::DoBlasGemmStridedBatched(
2531     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2532     uint64 n, uint64 k, blas::DataType dtype, const void *alpha,
2533     const DeviceMemoryBase &a, int lda, int64_t stride_a,
2534     const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
2535     DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) {
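  // Math-mode selection: before CUDA 11, fp16 tensor cores have to be
  // requested explicitly; from CUDA 11 on they are used by default, and the
  // explicit opt-in is instead TF32 for fp32 inputs (DoBlasInternalImpl
  // switches back to CUBLAS_DEFAULT_MATH when TensorFloat-32 execution is
  // disabled).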
2536   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2537 #if CUDA_VERSION < 11000
2538   if (dtype == dnn::kHalf) {
2539     math_type = CUBLAS_TENSOR_OP_MATH;
2540   }
2541 #else
2542   if (dtype == dnn::kFloat) {
2543     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2544   }
2545 #endif
2546 
2547   switch (dtype) {
2548 #if CUDA_VERSION >= 11000
2549     case dnn::kBF16: {
2550       CudaComputeCapability cc = stream->GetCudaComputeCapability();
2551       if (cc.IsAtLeast(7)) {
2552         cublasGemmAlgo_t algo =
2553             (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
2554         return DoBlasInternalImpl(
2555             AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
2556             true /* = pointer_mode_host */, math_type,
2557             CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2558             alpha, a.opaque(), CUDA_R_16BF, lda, stride_a, b.opaque(),
2559             CUDA_R_16BF, ldb, stride_b, beta, c->opaque(), CUDA_R_16BF, ldc,
2560             stride_c, batch_count,
2561             /*compute_type=*/CUDA_R_32F, algo);
2562       }
2563       // Fall back to a loop.
2564       for (int batch = 0; batch < batch_count; ++batch) {
2565         const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2566             static_cast<const Eigen::bfloat16 *>(a.opaque()) +
2567             batch * stride_a);
2568         const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2569             static_cast<const Eigen::bfloat16 *>(b.opaque()) +
2570             batch * stride_b);
2571         auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
2572             static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
2573         TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2574             cublasSgemmEx, stream, true /* = pointer_mode_host */,
2575             CUBLAS_DEFAULT_MATH, CUDABlasTranspose(transa),
2576             CUDABlasTranspose(transb), m, n, k,
2577             static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
2578             b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
2579             c_matrix, CUDA_R_16BF, ldc));
2580       }
2581       return port::Status::OK();
2582     }
2583 #endif
2584     case dnn::kHalf: {
2585 #if CUDA_VERSION >= 9010
2586       CudaComputeCapability cc = stream->GetCudaComputeCapability();
2587       if (cc.major >= 5) {
2588         cublasGemmAlgo_t algo =
2589             (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
2590         return DoBlasInternalImpl(
2591             AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
2592             true /* = pointer_mode_host */, math_type,
2593             CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2594             alpha, a.opaque(), CUDA_R_16F, lda, stride_a, b.opaque(),
2595             CUDA_R_16F, ldb, stride_b, beta, c->opaque(), CUDA_R_16F, ldc,
2596             stride_c, batch_count, CUDA_R_32F, algo);
2597       }
2598 #endif
2599       // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
2600       for (int batch = 0; batch < batch_count; ++batch) {
2601         const auto *a_matrix = reinterpret_cast<const __half *>(
2602             static_cast<const Eigen::half *>(a.opaque()) + batch * stride_a);
2603         const auto *b_matrix = reinterpret_cast<const __half *>(
2604             static_cast<const Eigen::half *>(b.opaque()) + batch * stride_b);
2605         auto *c_matrix = reinterpret_cast<__half *>(
2606             static_cast<Eigen::half *>(c->opaque()) + batch * stride_c);
2607         TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2608             cublasSgemmEx, stream, true /* = pointer_mode_host */,
2609             CUBLAS_DEFAULT_MATH, CUDABlasTranspose(transa),
2610             CUDABlasTranspose(transb), m, n, k,
2611             static_cast<const float *>(alpha), a_matrix, SE_CUDA_DATA_HALF, lda,
2612             b_matrix, SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
2613             c_matrix, SE_CUDA_DATA_HALF, ldc));
2614       }
2615       return port::Status::OK();
2616     }
2617     case dnn::kFloat: {
2618       return DoBlasInternalImpl(
2619           cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
2620           math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2621           k, static_cast<const float *>(alpha),
2622           static_cast<const float *>(a.opaque()), lda, stride_a,
2623           static_cast<const float *>(b.opaque()), ldb, stride_b,
2624           static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
2625           ldc, stride_c, batch_count);
2626     }
2627     case dnn::kDouble:
2628       return DoBlasInternalImpl(
2629           cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
2630           math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2631           k, static_cast<const double *>(alpha),
2632           static_cast<const double *>(a.opaque()), lda, stride_a,
2633           static_cast<const double *>(b.opaque()), ldb, stride_b,
2634           static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
2635           ldc, stride_c, batch_count);
2636     case dnn::kComplexFloat: {
2637       GpuComplexType cb_alpha =
2638           GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
2639       GpuComplexType cb_beta =
2640           GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
2641       return DoBlasInternalImpl(
2642           cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
2643           math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2644           k, GpuComplex(&cb_alpha),
2645           static_cast<const GpuComplexType *>(a.opaque()), lda, stride_a,
2646           static_cast<const GpuComplexType *>(b.opaque()), ldb, stride_b,
2647           GpuComplex(&cb_beta), static_cast<GpuComplexType *>(c->opaque()), ldc,
2648           stride_c, batch_count);
2649     }
2650     case dnn::kComplexDouble: {
2651       GpuDoubleComplexType cb_alpha =
2652           GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
2653       GpuDoubleComplexType cb_beta =
2654           GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
2655       return DoBlasInternalImpl(
2656           cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
2657           math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2658           k, GpuComplex(&cb_alpha),
2659           static_cast<const GpuDoubleComplexType *>(a.opaque()), lda, stride_a,
2660           static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, stride_b,
2661           GpuComplex(&cb_beta),
2662           static_cast<GpuDoubleComplexType *>(c->opaque()), ldc, stride_c,
2663           batch_count);
2664     }
2665     default:
2666       return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
2667                                               blas::DataTypeString(dtype)));
2668   }
2669 }
2670 
2671 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
2672                           blas::UpperLower uplo, uint64 m, uint64 n,
2673                           std::complex<float> alpha,
2674                           const DeviceMemory<std::complex<float>> &a, int lda,
2675                           const DeviceMemory<std::complex<float>> &b, int ldb,
2676                           std::complex<float> beta,
2677                           DeviceMemory<std::complex<float>> *c, int ldc) {
2678   auto cb_alpha = GpuComplexValue(alpha);
2679   auto cb_beta = GpuComplexValue(beta);
2680   return DoBlasInternal(cublasChemm, stream, true /* = pointer_mode_host */,
2681                         CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2682                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2683                         GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2684                         GpuComplex(GpuMemoryMutable(c)), ldc);
2685 }
2686 
2687 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
2688                           blas::UpperLower uplo, uint64 m, uint64 n,
2689                           std::complex<double> alpha,
2690                           const DeviceMemory<std::complex<double>> &a, int lda,
2691                           const DeviceMemory<std::complex<double>> &b, int ldb,
2692                           std::complex<double> beta,
2693                           DeviceMemory<std::complex<double>> *c, int ldc) {
2694   auto cb_alpha = GpuComplexValue(alpha);
2695   auto cb_beta = GpuComplexValue(beta);
2696   return DoBlasInternal(cublasZhemm, stream, true /* = pointer_mode_host */,
2697                         CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2698                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2699                         GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2700                         GpuComplex(GpuMemoryMutable(c)), ldc);
2701 }
2702 
2703 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2704                           blas::Transpose trans, uint64 n, uint64 k,
2705                           float alpha,
2706                           const DeviceMemory<std::complex<float>> &a, int lda,
2707                           float beta, DeviceMemory<std::complex<float>> *c,
2708                           int ldc) {
2709   return DoBlasInternal(cublasCherk, stream, true /* = pointer_mode_host */,
2710                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2711                         k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
2712                         GpuComplex(GpuMemoryMutable(c)), ldc);
2713 }
2714 
2715 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2716                           blas::Transpose trans, uint64 n, uint64 k,
2717                           double alpha,
2718                           const DeviceMemory<std::complex<double>> &a, int lda,
2719                           double beta, DeviceMemory<std::complex<double>> *c,
2720                           int ldc) {
2721   return DoBlasInternal(cublasZherk, stream, true /* = pointer_mode_host */,
2722                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2723                         k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
2724                         GpuComplex(GpuMemoryMutable(c)), ldc);
2725 }
2726 
2727 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2728                            blas::Transpose trans, uint64 n, uint64 k,
2729                            std::complex<float> alpha,
2730                            const DeviceMemory<std::complex<float>> &a, int lda,
2731                            const DeviceMemory<std::complex<float>> &b, int ldb,
2732                            float beta, DeviceMemory<std::complex<float>> *c,
2733                            int ldc) {
2734   auto cb_alpha = GpuComplexValue(alpha);
2735   return DoBlasInternal(cublasCher2k, stream, true /* = pointer_mode_host */,
2736                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2737                         k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2738                         GpuComplex(GpuMemory(b)), ldb, &beta,
2739                         GpuComplex(GpuMemoryMutable(c)), ldc);
2740 }
2741 
2742 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2743                            blas::Transpose trans, uint64 n, uint64 k,
2744                            std::complex<double> alpha,
2745                            const DeviceMemory<std::complex<double>> &a, int lda,
2746                            const DeviceMemory<std::complex<double>> &b, int ldb,
2747                            double beta, DeviceMemory<std::complex<double>> *c,
2748                            int ldc) {
2749   auto cb_alpha = GpuComplexValue(alpha);
2750   return DoBlasInternal(cublasZher2k, stream, true /* = pointer_mode_host */,
2751                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2752                         k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2753                         GpuComplex(GpuMemory(b)), ldb, &beta,
2754                         GpuComplex(GpuMemoryMutable(c)), ldc);
2755 }
2756 
2757 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2758                           blas::UpperLower uplo, uint64 m, uint64 n,
2759                           float alpha, const DeviceMemory<float> &a, int lda,
2760                           const DeviceMemory<float> &b, int ldb, float beta,
2761                           DeviceMemory<float> *c, int ldc) {
2762   return DoBlasInternal(cublasSsymm, stream, true /* = pointer_mode_host */,
2763                         CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2764                         &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2765                         GpuMemoryMutable(c), ldc);
2766 }
2767 
2768 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2769                           blas::UpperLower uplo, uint64 m, uint64 n,
2770                           double alpha, const DeviceMemory<double> &a, int lda,
2771                           const DeviceMemory<double> &b, int ldb, double beta,
2772                           DeviceMemory<double> *c, int ldc) {
2773   return DoBlasInternal(cublasDsymm, stream, true /* = pointer_mode_host */,
2774                         CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2775                         &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2776                         GpuMemoryMutable(c), ldc);
2777 }
2778 
2779 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2780                           blas::UpperLower uplo, uint64 m, uint64 n,
2781                           std::complex<float> alpha,
2782                           const DeviceMemory<std::complex<float>> &a, int lda,
2783                           const DeviceMemory<std::complex<float>> &b, int ldb,
2784                           std::complex<float> beta,
2785                           DeviceMemory<std::complex<float>> *c, int ldc) {
2786   auto cb_alpha = GpuComplexValue(alpha);
2787   auto cb_beta = GpuComplexValue(beta);
2788   return DoBlasInternal(cublasCsymm, stream, true /* = pointer_mode_host */,
2789                         CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2790                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2791                         GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2792                         GpuComplex(GpuMemoryMutable(c)), ldc);
2793 }
2794 
2795 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2796                           blas::UpperLower uplo, uint64 m, uint64 n,
2797                           std::complex<double> alpha,
2798                           const DeviceMemory<std::complex<double>> &a, int lda,
2799                           const DeviceMemory<std::complex<double>> &b, int ldb,
2800                           std::complex<double> beta,
2801                           DeviceMemory<std::complex<double>> *c, int ldc) {
2802   auto cb_alpha = GpuComplexValue(alpha);
2803   auto cb_beta = GpuComplexValue(beta);
2804   return DoBlasInternal(cublasZsymm, stream, true /* = pointer_mode_host */,
2805                         CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2806                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2807                         GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2808                         GpuComplex(GpuMemoryMutable(c)), ldc);
2809 }
2810 
2811 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2812                           blas::Transpose trans, uint64 n, uint64 k,
2813                           float alpha, const DeviceMemory<float> &a, int lda,
2814                           float beta, DeviceMemory<float> *c, int ldc) {
2815   return DoBlasInternal(cublasSsyrk, stream, true /* = pointer_mode_host */,
2816                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2817                         k, &alpha, GpuMemory(a), lda, &beta,
2818                         GpuMemoryMutable(c), ldc);
2819 }
2820 
2821 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2822                           blas::Transpose trans, uint64 n, uint64 k,
2823                           double alpha, const DeviceMemory<double> &a, int lda,
2824                           double beta, DeviceMemory<double> *c, int ldc) {
2825   return DoBlasInternal(cublasDsyrk, stream, true /* = pointer_mode_host */,
2826                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2827                         k, &alpha, GpuMemory(a), lda, &beta,
2828                         GpuMemoryMutable(c), ldc);
2829 }
2830 
2831 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2832                           blas::Transpose trans, uint64 n, uint64 k,
2833                           std::complex<float> alpha,
2834                           const DeviceMemory<std::complex<float>> &a, int lda,
2835                           std::complex<float> beta,
2836                           DeviceMemory<std::complex<float>> *c, int ldc) {
2837   auto cb_alpha = GpuComplexValue(alpha);
2838   auto cb_beta = GpuComplexValue(beta);
2839   return DoBlasInternal(cublasCsyrk, stream, true /* = pointer_mode_host */,
2840                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2841                         k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2842                         GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
2843                         ldc);
2844 }
2845 
2846 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2847                           blas::Transpose trans, uint64 n, uint64 k,
2848                           std::complex<double> alpha,
2849                           const DeviceMemory<std::complex<double>> &a, int lda,
2850                           std::complex<double> beta,
2851                           DeviceMemory<std::complex<double>> *c, int ldc) {
2852   auto cb_alpha = GpuComplexValue(alpha);
2853   auto cb_beta = GpuComplexValue(beta);
2854   return DoBlasInternal(cublasZsyrk, stream, true /* = pointer_mode_host */,
2855                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2856                         k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2857                         GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
2858                         ldc);
2859 }
2860 
2861 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2862                            blas::Transpose trans, uint64 n, uint64 k,
2863                            float alpha, const DeviceMemory<float> &a, int lda,
2864                            const DeviceMemory<float> &b, int ldb, float beta,
2865                            DeviceMemory<float> *c, int ldc) {
2866   return DoBlasInternal(cublasSsyr2k, stream, true /* = pointer_mode_host */,
2867                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2868                         k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2869                         GpuMemoryMutable(c), ldc);
2870 }
2871 
2872 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2873                            blas::Transpose trans, uint64 n, uint64 k,
2874                            double alpha, const DeviceMemory<double> &a, int lda,
2875                            const DeviceMemory<double> &b, int ldb, double beta,
2876                            DeviceMemory<double> *c, int ldc) {
2877   return DoBlasInternal(cublasDsyr2k, stream, true /* = pointer_mode_host */,
2878                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2879                         k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2880                         GpuMemoryMutable(c), ldc);
2881 }
2882 
2883 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2884                            blas::Transpose trans, uint64 n, uint64 k,
2885                            std::complex<float> alpha,
2886                            const DeviceMemory<std::complex<float>> &a, int lda,
2887                            const DeviceMemory<std::complex<float>> &b, int ldb,
2888                            std::complex<float> beta,
2889                            DeviceMemory<std::complex<float>> *c, int ldc) {
2890   auto cb_alpha = GpuComplexValue(alpha);
2891   auto cb_beta = GpuComplexValue(beta);
2892   return DoBlasInternal(cublasCsyr2k, stream, true /* = pointer_mode_host */,
2893                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2894                         k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2895                         GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2896                         GpuComplex(GpuMemoryMutable(c)), ldc);
2897 }
2898 
2899 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2900                            blas::Transpose trans, uint64 n, uint64 k,
2901                            std::complex<double> alpha,
2902                            const DeviceMemory<std::complex<double>> &a, int lda,
2903                            const DeviceMemory<std::complex<double>> &b, int ldb,
2904                            std::complex<double> beta,
2905                            DeviceMemory<std::complex<double>> *c, int ldc) {
2906   auto cb_alpha = GpuComplexValue(alpha);
2907   auto cb_beta = GpuComplexValue(beta);
2908   return DoBlasInternal(cublasZsyr2k, stream, true /* = pointer_mode_host */,
2909                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2910                         k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2911                         GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2912                         GpuComplex(GpuMemoryMutable(c)), ldc);
2913 }
2914 
2915 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2916                           blas::UpperLower uplo, blas::Transpose transa,
2917                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2918                           const DeviceMemory<float> &a, int lda,
2919                           DeviceMemory<float> *b, int ldb) {
2920   return DoBlasInternal(cublasStrmm, stream, true /* = pointer_mode_host */,
2921                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
2922                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2923                         &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
2924                         GpuMemoryMutable(b), ldb);
2925 }
2926 
2927 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2928                           blas::UpperLower uplo, blas::Transpose transa,
2929                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2930                           const DeviceMemory<double> &a, int lda,
2931                           DeviceMemory<double> *b, int ldb) {
2932   return DoBlasInternal(cublasDtrmm, stream, true /* = pointer_mode_host */,
2933                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
2934                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2935                         &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
2936                         GpuMemoryMutable(b), ldb);
2937 }
2938 
2939 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2940                           blas::UpperLower uplo, blas::Transpose transa,
2941                           blas::Diagonal diag, uint64 m, uint64 n,
2942                           std::complex<float> alpha,
2943                           const DeviceMemory<std::complex<float>> &a, int lda,
2944                           DeviceMemory<std::complex<float>> *b, int ldb) {
2945   auto cb_alpha = GpuComplexValue(alpha);
2946   return DoBlasInternal(cublasCtrmm, stream, true /* = pointer_mode_host */,
2947                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
2948                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2949                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2950                         GpuComplex(GpuMemoryMutable(b)), ldb,
2951                         GpuComplex(GpuMemoryMutable(b)), ldb);
2952 }
2953 
2954 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2955                           blas::UpperLower uplo, blas::Transpose transa,
2956                           blas::Diagonal diag, uint64 m, uint64 n,
2957                           std::complex<double> alpha,
2958                           const DeviceMemory<std::complex<double>> &a, int lda,
2959                           DeviceMemory<std::complex<double>> *b, int ldb) {
2960   auto cb_alpha = GpuComplexValue(alpha);
2961   return DoBlasInternal(cublasZtrmm, stream, true /* = pointer_mode_host */,
2962                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
2963                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2964                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2965                         GpuComplex(GpuMemoryMutable(b)), ldb,
2966                         GpuComplex(GpuMemoryMutable(b)), ldb);
2967 }
2968 
2969 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
2970                           blas::UpperLower uplo, blas::Transpose transa,
2971                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2972                           const DeviceMemory<float> &a, int lda,
2973                           DeviceMemory<float> *b, int ldb) {
2974   return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
2975                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
2976                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2977                         &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
2978 }
2979 
2980 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
2981                           blas::UpperLower uplo, blas::Transpose transa,
2982                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2983                           const DeviceMemory<double> &a, int lda,
2984                           DeviceMemory<double> *b, int ldb) {
2985   return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
2986                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
2987                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2988                         &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
2989 }
2990 
2991 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
2992                           blas::UpperLower uplo, blas::Transpose transa,
2993                           blas::Diagonal diag, uint64 m, uint64 n,
2994                           std::complex<float> alpha,
2995                           const DeviceMemory<std::complex<float>> &a, int lda,
2996                           DeviceMemory<std::complex<float>> *b, int ldb) {
2997   auto cb_alpha = GpuComplexValue(alpha);
2998   return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
2999                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
3000                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3001                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3002                         GpuComplex(GpuMemoryMutable(b)), ldb);
3003 }
3004 
3005 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
3006                           blas::UpperLower uplo, blas::Transpose transa,
3007                           blas::Diagonal diag, uint64 m, uint64 n,
3008                           std::complex<double> alpha,
3009                           const DeviceMemory<std::complex<double>> &a, int lda,
3010                           DeviceMemory<std::complex<double>> *b, int ldb) {
3011   auto cb_alpha = GpuComplexValue(alpha);
3012   return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
3013                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
3014                         CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3015                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3016                         GpuComplex(GpuMemoryMutable(b)), ldb);
3017 }
3018 
3019 // We only use cublasLt from CUDA 11.0 onward.
3020 #if CUDA_VERSION >= 11000
3021 
3022 namespace {
3023 
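// Helpers for setting cublasLt descriptor attributes; failures from the
// cublasLt*SetAttribute calls are converted into port::Status errors with a
// readable message.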
3024 template <typename T>
3025 inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
3026                                     cublasLtMatrixLayoutAttribute_t attr,
3027                                     const T &value) {
3028   cublasStatus_t status =
3029       cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
3030   if (status != CUBLAS_STATUS_SUCCESS) {
3031     return port::Status(
3032         port::error::INTERNAL,
3033         absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
3034                      ", value=", value, ") failed: ", ToString(status)));
3035   }
3036   return port::Status::OK();
3037 }
3038 
3039 template <typename T>
3040 inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle,
3041                                     cublasLtMatmulAlgoConfigAttributes_t attr,
3042                                     const T &value) {
3043   cublasStatus_t status =
3044       cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
3045   if (status != CUBLAS_STATUS_SUCCESS) {
3046     return port::Status(
3047         port::error::INTERNAL,
3048         absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
3049                      ", value=", value, ") failed: ", ToString(status)));
3050   }
3051   return port::Status::OK();
3052 }
3053 
3054 template <typename T>
3055 inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
3056                                     cublasLtMatmulPreferenceAttributes_t attr,
3057                                     const T &value) {
3058   cublasStatus_t status =
3059       cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
3060   if (status != CUBLAS_STATUS_SUCCESS) {
3061     return port::Status(
3062         port::error::INTERNAL,
3063         absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
3064                      ", value=", value, ") failed: ", ToString(status)));
3065   }
3066   return port::Status::OK();
3067 }
3068 
3069 template <typename T>
3070 inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle,
3071                             cublasLtMatmulAlgoConfigAttributes_t attr,
3072                             T *value) {
3073   auto mutable_handle = const_cast<cublasLtMatmulAlgo_t *>(handle);
3074   size_t bytes_written = 0;
3075   return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
3076                                               sizeof(T), &bytes_written) ==
3077              CUBLAS_STATUS_SUCCESS &&
3078          bytes_written == sizeof(T);
3079 }
3080 
3081 template <typename T>
3082 inline const T &ValueForStrCat(const T &value) {
3083   return value;
3084 }
3085 template <typename T>
3086 inline absl::Hex ValueForStrCat(T *ptr) {
3087   return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
3088 }
3089 
3090 template <typename T>
3091 inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
3092                                     cublasLtMatmulDescAttributes_t attr,
3093                                     const T &value) {
3094   cublasStatus_t status =
3095       cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
3096   if (status != CUBLAS_STATUS_SUCCESS) {
3097     return port::Status(
3098         port::error::INTERNAL,
3099         absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
3100                      ValueForStrCat(value), ") failed: ", ToString(status)));
3101   }
3102   return port::Status::OK();
3103 }
3104 
3105 struct MatmulDescDestroyer {
3106   void operator()(cublasLtMatmulDesc_t matmul_desc) const {
3107     cublasLtMatmulDescDestroy(matmul_desc);
3108   }
3109 };
3110 struct LayoutDestroyer {
3111   void operator()(cublasLtMatrixLayout_t layout) const {
3112     cublasLtMatrixLayoutDestroy(layout);
3113   }
3114 };
3115 struct MatmulPreferenceDestroyer {
3116   void operator()(cublasLtMatmulPreference_t matmul_pref) const {
3117     cublasLtMatmulPreferenceDestroy(matmul_pref);
3118   }
3119 };
3120 using UniqueOpDesc =
3121     std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
3122                     MatmulDescDestroyer>;
3123 using UniqueLayoutDesc =
3124     std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
3125                     LayoutDestroyer>;
3126 using UniqueMatmulPreference =
3127     std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
3128                     MatmulPreferenceDestroyer>;
3129 
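// Creates a cublasLtMatmulDesc_t configured with the given computation type,
// scale type, pointer mode, epilogue and transpose settings, wrapped in an
// RAII handle.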
3130 port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
3131     blas::ComputationType computation_type, blas::DataType scale_type,
3132     blas::PointerMode pointer_mode, blas::Epilogue epilogue,
3133     blas::Transpose transa, blas::Transpose transb) {
3134   cublasLtMatmulDesc_t desc;
3135   cublasComputeType_t cublas_compute_type =
3136       CUBLASComputationType(computation_type);
3137   cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
3138   cublasStatus_t status =
3139       cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
3140   if (status != CUBLAS_STATUS_SUCCESS) {
3141     return port::Status(
3142         port::error::INTERNAL,
3143         absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
3144                      computation_type, ") failed: ", ToString(status)));
3145   }
3146   UniqueOpDesc unique_desc(desc);
3147   SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
3148                                      CUBLASPointerMode(pointer_mode)));
3149   SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
3150                                      CUBLASEpilogue(epilogue)));
3151   SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
3152                                      CUDABlasTranspose(transa)));
3153   SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
3154                                      CUDABlasTranspose(transb)));
3155   return unique_desc;
3156 }
3157 
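// Creates a cublasLtMatrixLayout_t describing one (possibly strided-batched)
// matrix operand, wrapped in an RAII handle.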
3158 port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
3159     blas::DataType data_type, uint64 rows, uint64 cols, int64_t ld,
3160     int64_t stride, int batch_count) {
3161   cublasLtMatrixLayout_t desc;
3162   cublasStatus_t status = cublasLtMatrixLayoutCreate(
3163       &desc, GetCUDADataType(data_type), rows, cols, ld);
3164   if (status != CUBLAS_STATUS_SUCCESS) {
3165     return port::Status(
3166         port::error::INTERNAL,
3167         absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
3168   }
3169   UniqueLayoutDesc unique_desc(desc);
3170   SE_RETURN_IF_ERROR(
3171       SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
3172   SE_RETURN_IF_ERROR(SetCublasLtAttr(
3173       desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
3174   return unique_desc;
3175 }
3176 
3177 // Helper function to allocate workspace.
3178 port::Status AllocateWorkspace(void **workspace,
3179                                ScratchAllocator *scratch_allocator,
3180                                size_t num_bytes) {
3181   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
3182                       scratch_allocator->AllocateBytes(num_bytes));
3183   *workspace = (void *)GpuMemoryMutable(&workspace_bytes);
3184   return port::Status::OK();
3185 }
3186 
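// Maps a C++ element type to the corresponding blas::ComputationType.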
3187 template <typename T>
3188 blas::ComputationType ToComputationType();
3189 template <>
3190 blas::ComputationType ToComputationType<Eigen::half>() {
3191   return blas::ComputationType::kF16;
3192 }
3193 template <>
3194 blas::ComputationType ToComputationType<float>() {
3195   return blas::ComputationType::kF32;
3196 }
3197 template <>
3198 blas::ComputationType ToComputationType<double>() {
3199   return blas::ComputationType::kF64;
3200 }
3201 template <>
3202 blas::ComputationType ToComputationType<std::complex<float>>() {
3203   return blas::ComputationType::kComplexF32;
3204 }
3205 template <>
3206 blas::ComputationType ToComputationType<std::complex<double>>() {
3207   return blas::ComputationType::kComplexF64;
3208 }
3209 
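// Concrete blas::IBlasLtMatmulPlan holding the cublasLt operation descriptor
// and the A/B/C/D matrix-layout descriptors (plus remainder-batch layouts for
// batch counts above kMaxBatchCount).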
3210 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
3211  public:
3212   port::Status init(const blas::BlasLtMatmulPlanParams &p) {
3213     params_ = p;
3214     scale_type_ = GetScaleType(p.c_type, p.computation_type);
3215     SE_ASSIGN_OR_RETURN(
3216         op_desc_,
3217         CreateCublasLtOperationDesc(
3218             p.computation_type, GetScaleType(p.c_type, p.computation_type),
3219             p.pointer_mode, p.epilogue, p.transa, p.transb));
3220     uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
3221     uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
3222     uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
3223     uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
3224     SE_ASSIGN_OR_RETURN(
3225         a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
3226                                           p.stride_a, capped_batch_count()));
3227     SE_ASSIGN_OR_RETURN(
3228         b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
3229                                           p.stride_b, capped_batch_count()));
3230     SE_ASSIGN_OR_RETURN(
3231         c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3232                                           capped_batch_count()));
3233     SE_ASSIGN_OR_RETURN(
3234         d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3235                                           capped_batch_count()));
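    // If batch_count exceeds kMaxBatchCount, the leftover
    // p.batch_count % kMaxBatchCount batches get their own layout descriptors,
    // created below.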
3236     remainder_batch_count_ =
3237         p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
3238     if (remainder_batch_count_) {
3239       SE_ASSIGN_OR_RETURN(
3240           a_remainder_desc_,
3241           CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
3242                                    remainder_batch_count_));
3243       SE_ASSIGN_OR_RETURN(
3244           b_remainder_desc_,
3245           CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
3246                                    remainder_batch_count_));
3247       SE_ASSIGN_OR_RETURN(
3248           c_remainder_desc_,
3249           CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3250                                    remainder_batch_count_));
3251       SE_ASSIGN_OR_RETURN(
3252           d_remainder_desc_,
3253           CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3254                                    remainder_batch_count_));
3255     }
3256     return port::Status::OK();
3257   }
3258 
3259   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
3260   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
3261   cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
3262   cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
3263   cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
3264   cublasLtMatrixLayout_t a_remainder_desc() const {
3265     return a_remainder_desc_.get();
3266   }
3267   cublasLtMatrixLayout_t b_remainder_desc() const {
3268     return b_remainder_desc_.get();
3269   }
3270   cublasLtMatrixLayout_t c_remainder_desc() const {
3271     return c_remainder_desc_.get();
3272   }
3273   cublasLtMatrixLayout_t d_remainder_desc() const {
3274     return d_remainder_desc_.get();
3275   }
3276 
3277   const blas::BlasLtMatmulPlanParams &params() const { return params_; }
3278   blas::DataType scale_type() const { return scale_type_; }
3279   blas::DataType ab_type() const override { return params_.ab_type; }
3280   blas::DataType c_type() const override { return params_.c_type; }
3281   int capped_batch_count() const {
3282     return std::min(params_.batch_count, kMaxBatchCount);
3283   }
3284   int remainder_batch_count() const { return remainder_batch_count_; }
3285 
3286   // Note: Must be const to satisfy API. This is always called before the plan
3287   // is executed, so the state change is not observed in subsequent executions.
3288   bool SetBiasPointer(const void *bias) const;
3289 
3290  private:
3291   // In some cases cublasLt does not support large batch sizes, so we need to
3292   // split up such cases into multiple calls.
3293   static constexpr int kMaxBatchCount = 65535;
3294   blas::BlasLtMatmulPlanParams params_;
3295   blas::DataType scale_type_;
3296   UniqueOpDesc op_desc_;
3297   // These have batch count set to capped_batch_count().
3298   UniqueLayoutDesc a_desc_;
3299   UniqueLayoutDesc b_desc_;
3300   UniqueLayoutDesc c_desc_;
3301   UniqueLayoutDesc d_desc_;
3302   int remainder_batch_count_;
3303   // These have batch count set to remainder_batch_count_, and are only created
3304   // if params_.batch_count > kMaxBatchCount.
3305   UniqueLayoutDesc a_remainder_desc_;
3306   UniqueLayoutDesc b_remainder_desc_;
3307   UniqueLayoutDesc c_remainder_desc_;
3308   UniqueLayoutDesc d_remainder_desc_;
3309 };
3310 
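// Out-of-line definition of the static constexpr member; required when the
// constant is odr-used (e.g. passed by reference to std::min) prior to C++17.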
3311 /*static*/ constexpr int CUDABlasLtMatmulPlan::kMaxBatchCount;
3312 
3313 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
3314   return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
3315                          bias)
3316       .ok();
3317 }
3318 
3319 class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
3320  public:
3321   CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
3322                             cublasLtMatmulAlgo_t algo, size_t workspace_size)
3323       : index_(index), algo_(algo), workspace_size_(workspace_size) {}
3324 
3325   blas::AlgorithmType index() const override { return index_; }
3326 
3327   size_t workspace_size() const override { return workspace_size_; }
3328 
3329   const cublasLtMatmulAlgo_t *algo() const { return &algo_; }
3330 
3331   int algo_id() const {
3332     int id;
3333     GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
3334     return id;
3335   }
3336 
3337  private:
3338   blas::AlgorithmType index_;
3339   cublasLtMatmulAlgo_t algo_;
3340   size_t workspace_size_;
3341 };
3342 
3343 port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
3344     const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) {
3345   cublasLtMatmulPreference_t preference;
3346   cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
3347   if (status != CUBLAS_STATUS_SUCCESS) {
3348     return port::Status(port::error::INTERNAL,
3349                         absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
3350                                      ToString(status)));
3351   }
3352   UniqueMatmulPreference unique_preference(preference);
3353   SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
3354                                      CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
3355                                      max_workspace_bytes));
3356 
3357   const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3358   if (cuda_plan.params().batch_count == 0) {
3359     return unique_preference;
3360   }
3361   // This is a workaround for a known issue in cuBlasLt where the heuristic may
3362   // in rare cases select an algo that does not support the specified stride.
3363   // Specifying the alignment requirements manually like this avoids the issue.
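  // (stride & -stride) isolates the lowest set bit of the stride, i.e. the
  // largest power-of-two factor of the stride in elements; multiplying by the
  // element size turns that into a guaranteed byte alignment.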
3364   auto get_alignment_bytes = [](int64_t stride, blas::DataType dtype) {
3365     return (stride & -stride) * GetDataTypeSizeBytes(dtype);
3366   };
3367   if (cuda_plan.params().stride_a) {
3368     SE_RETURN_IF_ERROR(SetCublasLtAttr(
3369         preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
3370         (uint32)get_alignment_bytes(cuda_plan.params().stride_a,
3371                                     cuda_plan.params().ab_type)));
3372   }
3373   if (cuda_plan.params().stride_b) {
3374     SE_RETURN_IF_ERROR(SetCublasLtAttr(
3375         preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
3376         (uint32)get_alignment_bytes(cuda_plan.params().stride_b,
3377                                     cuda_plan.params().ab_type)));
3378   }
3379   if (cuda_plan.params().stride_c) {
3380     SE_RETURN_IF_ERROR(SetCublasLtAttr(
3381         preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
3382         (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
3383                                     cuda_plan.params().c_type)));
3384   }
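  // D shares C's stride and element type (see the d_desc_ layout above), so the
  // D alignment bound is derived from the same values.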
3385   if (cuda_plan.params().stride_c) {
3386     SE_RETURN_IF_ERROR(SetCublasLtAttr(
3387         preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
3388         (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
3389                                     cuda_plan.params().c_type)));
3390   }
3391   return unique_preference;
3392 }
3393 
3394 }  // namespace
3395 
3396 #endif  // CUDA_VERSION >= 11000
3397 
3398 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
3399 CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
3400 #if CUDA_VERSION >= 11000
3401   auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
3402   SE_RETURN_IF_ERROR(cuda_plan->init(p));
3403   return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
3404       std::move(cuda_plan));
3405 #else
3406   return port::Status(
3407       port::error::UNIMPLEMENTED,
3408       "CreateBlasLtMatmulPlan is not supported with this version of CUDA");
3409 #endif
3410 }
3411 
3412 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
3413 CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
3414                                             size_t max_workspace_size,
3415                                             int max_algorithm_count,
3416                                             bool for_remainder_batch) {
3417 #if CUDA_VERSION >= 11000
3418   SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
3419                       CreateCublasLtMatmulPreference(plan, max_workspace_size));
3420 
3421   std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
3422   {
3423     absl::MutexLock lock(&mu_);
3424 
3425     CHECK(blasLt_ != nullptr);
3426 
3427     gpu::ScopedActivateExecutorContext sac{parent_};
3428 
3429     int found_algorithm_count = 0;
3430     const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3431     const auto &a_desc =
3432         for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
3433     const auto &b_desc =
3434         for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
3435     const auto &c_desc =
3436         for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
3437     const auto &d_desc =
3438         for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
3439     cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
3440         blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
3441         preference.get(), max_algorithm_count, results.data(),
3442         &found_algorithm_count);
3443     if (status != CUBLAS_STATUS_SUCCESS) {
3444       return port::Status(
3445           port::error::INTERNAL,
3446           absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
3447                        ToString(status)));
3448     }
3449     results.resize(found_algorithm_count);
3450   }
3451 
3452   std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
3453   out_algorithms.reserve(results.size());
3454   for (size_t i = 0; i < results.size(); ++i) {
3455     const auto &result = results[i];
3456     if (result.state != CUBLAS_STATUS_SUCCESS) continue;  // Skip failed algos
3457     out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
3458         i, result.algo, result.workspaceSize));
3459   }
3460   return out_algorithms;
3461 #else  // if CUDA_VERSION < 11000
3462   return port::Status(
3463       port::error::UNIMPLEMENTED,
3464       "GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
3465 #endif
3466 }
3467 
3468 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
3469 CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
3470                                     size_t max_workspace_size,
3471                                     int max_algorithm_count) {
3472   return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
3473                                            max_algorithm_count);
3474 }
3475 
3476 #if CUDA_VERSION >= 11000
3477 bool CUDABlas::DoBlasLtMatmulInternal(
3478     Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
3479     const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
3480     DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
3481     DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator,
3482     const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) {
3483   const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3484   const auto &cuda_algo =
3485       *static_cast<const CUDABlasLtMatmulAlgorithm *>(algorithm);
3486 
3487   if (alpha.data_type() != cuda_plan.scale_type() ||
3488       beta.data_type() != cuda_plan.scale_type()) {
3489     VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do "
3490                "not match plan: expected "
3491             << cuda_plan.scale_type() << ", got alpha=" << alpha.data_type()
3492             << " beta=" << beta.data_type();
3493     return false;
3494   }
3495   if (alpha.is_pointer() != beta.is_pointer()) {
3496     VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
3497                "and `beta` is a pointer, but the other is not.";
3498     return false;
3499   }
3500   bool is_pointer_mode_host = !alpha.is_pointer();
3501   if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
3502       is_pointer_mode_host) {
3503     VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
3504                "pointer_mode for the given alpha/beta.";
3505     return false;
3506   }
3507   if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
3508        cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
3509       (bias != nullptr)) {
3510     VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
3511                "epilogue for the given bias pointer.";
3512     return false;
3513   }
3514   const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
3515                                              : alpha.opaque_value();
3516   const void *beta_ptr =
3517       beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value();
3518 
3519   void *workspace = nullptr;
3520   if (cuda_algo.workspace_size()) {
3521     port::Status allocation_status = AllocateWorkspace(
3522         &workspace, scratch_allocator, cuda_algo.workspace_size());
3523     if (!allocation_status.ok()) {
3524       if (err_on_failure || VLOG_IS_ON(3)) {
3525         LOG(ERROR)
3526             << "Failed to allocate workspace for cublasLtMatmul algo with id: "
3527             << cuda_algo.algo_id() << " requiring "
3528             << cuda_algo.workspace_size() << " bytes of workspace";
3529       }
3530       return false;
3531     }
3532   }
3533 
3534   // This is only used when batch_count > kMaxBatchCount.
3535   std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
3536   if (cuda_plan.remainder_batch_count()) {
3537     // There is no easy way to get the user-specified max workspace size here,
3538     // so we just allow a very small amount and don't worry too much about
3539     // performance because this is only used in rare cases. The same reasoning
3540     // applies to selection of the algorithm.
3541     size_t max_workspace_size = 4 * 1024 * 1024;  // 4 MiB
3542     auto status_or_algorithms =
3543         GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
3544                                           /* max_algorithm_count = */ 1,
3545                                           /* for_remainder_batch = */ true);
3546     if (!status_or_algorithms.ok()) {
3547       if (err_on_failure || VLOG_IS_ON(3)) {
3548         LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
3549       }
3550       return false;
3551     }
3552     auto algorithms = status_or_algorithms.ConsumeValueOrDie();
3553     unique_remainder_algo = std::move(algorithms.front());
3554   }
3555 
3556   cudaStream_t cuda_stream = CUDAStream(stream);
3557 
3558   absl::MutexLock lock(&mu_);
3559 
3560   if (bias != nullptr) {
3561     if (!cuda_plan.SetBiasPointer(bias.opaque())) {
3562       VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
3563                  "pointer failed.";
3564       return false;
3565     }
3566   }
3567 
3568   CHECK(blasLt_ != nullptr);
3569 
3570   gpu::ScopedActivateExecutorContext sac{parent_};
3571 
3572   // Plan execution is broken down into repeated calls with capped_batch_count,
3573   // followed by a final call with remainder_batch_count.
3574   // Cases where batch_count <= kMaxBatchCount require only a single call (a
3575   // single loop iteration and no remainder).
3576   int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
3577   int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
3578   const char *a_ptr = static_cast<const char *>(a.opaque());
3579   const char *b_ptr = static_cast<const char *>(b.opaque());
3580   const char *c_ptr = static_cast<const char *>(c.opaque());
3581   char *d_ptr = static_cast<char *>(d.opaque());
3582   int capped_batch_count = cuda_plan.capped_batch_count();
3583   for (int batch = 0;
3584        batch + capped_batch_count <= cuda_plan.params().batch_count;
3585        batch += capped_batch_count) {
3586     cublasStatus_t ret = cublasLtMatmul(
3587         blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
3588         b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
3589         cuda_plan.d_desc(), cuda_algo.algo(), workspace,
3590         cuda_algo.workspace_size(), cuda_stream);
3591     if (ret != CUBLAS_STATUS_SUCCESS) {
3592       if (err_on_failure || VLOG_IS_ON(3)) {
3593         LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
3594       }
3595       return false;
3596     }
3597     a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
3598     b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
3599     c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
3600     d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
3601   }
3602   // This is only used when batch_count > kMaxBatchCount.
3603   if (cuda_plan.remainder_batch_count()) {
3604     const auto &remainder_algo =
3605         *static_cast<const CUDABlasLtMatmulAlgorithm *>(
3606             unique_remainder_algo.get());
3607     if (remainder_algo.workspace_size()) {
3608       port::Status allocation_status = AllocateWorkspace(
3609           &workspace, scratch_allocator, remainder_algo.workspace_size());
3610       if (!allocation_status.ok()) {
3611         if (err_on_failure || VLOG_IS_ON(3)) {
3612           LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
3613                         "with id: "
3614                      << remainder_algo.algo_id() << " requiring "
3615                      << remainder_algo.workspace_size()
3616                      << " bytes of workspace";
3617         }
3618         return false;
3619       }
3620     }
3621     cublasStatus_t ret = cublasLtMatmul(
3622         blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
3623         cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
3624         beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
3625         cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
3626         remainder_algo.workspace_size(), cuda_stream);
3627     if (ret != CUBLAS_STATUS_SUCCESS) {
3628       if (err_on_failure || VLOG_IS_ON(3)) {
3629         LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
3630                    << ToString(ret);
3631       }
3632       return false;
3633     }
3634   }
3635   return true;
3636 }
3637 #endif  // CUDA_VERSION >= 11000
3638 
3639 bool CUDABlas::DoBlasLtMatmul(
3640     Stream *stream, const blas::IBlasLtMatmulPlan *plan,
3641     const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
3642     DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
3643     DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
3644     const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
3645     blas::ProfileResult *output_profile_result) {
3646 #if CUDA_VERSION >= 11000
3647   const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3648   HostOrDeviceScalar<void> alpha_cast = alpha;
3649   HostOrDeviceScalar<void> beta_cast = beta;
3650   if (cuda_plan.c_type() == blas::DataType::kHalf &&
3651       cuda_plan.scale_type() == blas::DataType::kFloat) {
3652     // The given alpha and beta types are F16 (they always match c), but F32*
3653     // computation type requires that they be F32, so we must cast them.
3654     if (alpha.is_pointer() || beta.is_pointer()) {
3655       // We cannot easily convert a pointer to f16 memory to a pointer to f32
3656       // memory from here, so we don't support this for now.
3657       return false;
3658     }
3659     alpha_cast = HostOrDeviceScalar<void>(
3660         static_cast<float>(alpha.value<Eigen::half>()));
3661     beta_cast =
3662         HostOrDeviceScalar<void>(static_cast<float>(beta.value<Eigen::half>()));
3663   }
3664 
3665   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3666   if (output_profile_result) {
3667     timer.reset(new GpuTimer(parent_));
3668     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3669       return false;
3670     }
3671   }
3672 
3673   bool err_on_failure = timer != nullptr;
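  // The same buffer is passed as both the C (input) and D (output) matrices, so
  // the result of the matmul overwrites c in place.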
3674   bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast,
3675                                        a, b, beta_cast, c, c, scratch_allocator,
3676                                        algorithm, bias);
3677 
3678   if (timer && result) {
3679     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
3680     // state.
3681     if (!timer->Stop(AsGpuStream(stream))) {
3682       return false;
3683     }
3684     output_profile_result->set_is_valid(true);
3685     output_profile_result->set_algorithm(algorithm->index());
3686     output_profile_result->set_elapsed_time_in_ms(
3687         timer->GetElapsedMilliseconds());
3688   }
3689   return result;
3690 #else  // if CUDA_VERSION < 11000
3691   return false;
3692 #endif
3693 }
3694 
3695 port::Status CUDABlas::GetVersion(std::string *version) {
3696   absl::MutexLock lock(&mu_);
3697 
3698   int v;
3699   auto status = cublasGetVersion(blas_, &v);
3700   if (status != CUBLAS_STATUS_SUCCESS) {
3701     return port::InternalError(ToString(status));
3702   }
3703   *version = std::to_string(v);
3704   return port::Status::OK();
3705 }
3706 
3707 }  // namespace gpu
3708 
3709 void initialize_cublas() {
3710   port::Status status =
3711       PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
3712           cuda::kCudaPlatformId, gpu::kCuBlasPlugin, "cuBLAS",
3713           [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * {
3714             gpu::GpuExecutor *cuda_executor =
3715                 dynamic_cast<gpu::GpuExecutor *>(parent);
3716             if (cuda_executor == nullptr) {
3717               LOG(ERROR)
3718                   << "Attempting to initialize an instance of the cuBLAS "
3719                   << "support library with a non-CUDA StreamExecutor";
3720               return nullptr;
3721             }
3722 
3723             gpu::CUDABlas *blas = new gpu::CUDABlas(cuda_executor);
3724             if (!blas->Init()) {
3725               // Note: Init() will log a more specific error.
3726               delete blas;
3727               return nullptr;
3728             }
3729             return blas;
3730           });
3731 
3732   if (!status.ok()) {
3733     LOG(ERROR) << "Unable to register cuBLAS factory: "
3734                << status.error_message();
3735   }
3736 
3737   PluginRegistry::Instance()->SetDefaultFactory(
3738       cuda::kCudaPlatformId, PluginKind::kBlas, gpu::kCuBlasPlugin);
3739 }
3740 
3741 }  // namespace stream_executor
3742 
3743 REGISTER_MODULE_INITIALIZER(register_cublas,
3744                             { stream_executor::initialize_cublas(); });
3745