/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"

#define SE_CUDA_DATA_HALF CUDA_R_16F

#include "tensorflow/stream_executor/cuda/cuda_blas.h"

// Both Eigen's Half.h and CUDA's cuda_fp16.h provide a similar typedef for
// __half, so there are two ways to obtain the typedef:
//
// (1) Include cuda_fp16.h and define EIGEN_HAS_CUDA_FP16.
// (2) Neither include cuda_fp16.h nor define EIGEN_HAS_CUDA_FP16.
//
// Due to issue b/73793421, when the first approach is used and NVCC compiles
// this file, NVCC complains about a duplicated definition of
// EIGEN_HAS_CUDA_FP16. Conversely, when the second approach is used and clang
// compiles this file, clang does not recognize __half, because both the
// definition and the macro EIGEN_HAS_CUDA_FP16 are missing.
//
// Because this file may be compiled with clang but will never be compiled
// with NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we
// have to use the second approach, because the data member of the __half
// defined by CUDA >= 9.0 is `__x` while Eigen expects it to be `x`.
//
// TODO(b/73793421): Remove the following code block to switch to the second
// approach when the issue is fixed.
#if CUDA_VERSION < 9000
#include "third_party/gpus/cuda/include/cuda_fp16.h"
#define EIGEN_HAS_CUDA_FP16
#endif

#include <complex>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace stream_executor {
namespace gpu {

PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);

static std::string ToString(cublasStatus_t status) {
  switch (status) {
    case CUBLAS_STATUS_SUCCESS:
      return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
#if CUDA_VERSION >= 8000
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "CUBLAS_STATUS_LICENSE_ERROR";
#endif
    default:
      return absl::StrCat("<invalid cublas status: ", status, ">");
  }
}

// cuBLAS has interfaces that permit pointers to be passed from either the host
// memory space or the device memory space; however, you must instruct it as to
// which address space those pointers are in with cublasSetPointerMode.
//
// This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS
// call you are about to perform in a given scope.
//
// The prior cuBLAS pointer mode is retained and restored when this object goes
// out of scope.
class ScopedCublasPointerMode {
 public:
  // Note that, because the setting of the cublas pointer mode is fallible,
  // construction of this scoped datatype must be paired with a call to
  // Init().
  //
  // Parameters:
  //  handle: The cublas library handle to act upon in setting the pointer
  //    mode.
  explicit ScopedCublasPointerMode(cublasHandle_t handle)
      : handle_(handle), ok_(false) {}

  // Attempts the switch to the requested scoped pointer mode, new_mode.
  //
  // Note that when false is returned, an appropriate error has already been
  // logged.
  bool Init(cublasPointerMode_t new_mode) {
    cublasStatus_t ret = cublasGetPointerMode(handle_, &old_mode_);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
      return ok_ = false;
    }

    ret = cublasSetPointerMode(handle_, new_mode);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
      return ok_ = false;
    }

    return ok_ = true;
  }

  // Switches back to the prior pointer mode, if the switch operation was
  // successful in the first place.
  ~ScopedCublasPointerMode() {
    if (ok_) {
      cublasStatus_t ret = cublasSetPointerMode(handle_, old_mode_);
      if (ret != CUBLAS_STATUS_SUCCESS) {
        LOG(ERROR) << "failed to set former cublas pointer mode: "
                   << ToString(ret);
      }
    }
  }

 private:
  cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
  cublasPointerMode_t old_mode_;  // Prior cuBLAS pointer mode, to be restored.
  bool ok_;                       // Whether the change was successful.
};
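
// Illustrative usage sketch, not part of the original file: the guard scopes
// a pointer-mode change around a single raw cuBLAS call. `handle`,
// `elem_count`, and `device_x` are hypothetical names for this example.
//
//   ScopedCublasPointerMode pointer_mode{handle};
//   if (!pointer_mode.Init(CUBLAS_POINTER_MODE_HOST)) {
//     return false;  // Init() has already logged the error.
//   }
//   float alpha = 2.0f;  // Read from host memory under host pointer mode.
//   cublasSscal(handle, elem_count, &alpha, device_x, /*incx=*/1);
//   // The prior pointer mode is restored when pointer_mode goes out of scope.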

#if CUDA_VERSION >= 9000
// cuBLAS has interfaces that permit computations to use the Volta hardware.
// This must be enabled via the cublasGet/SetMathMode APIs.
//
// This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
// you are about to perform in a given scope.
//
// The prior cuBLAS math mode is retained and restored when this object goes
// out of scope.
class ScopedCublasMathMode {
 public:
  // Note that, because the setting of the cublas math mode is fallible,
  // construction of this scoped datatype must be paired with a call to
  // Init().
  //
  // Parameters:
  //  handle: The cublas library handle to act upon in setting the math mode.
  explicit ScopedCublasMathMode(cublasHandle_t handle)
      : handle_(handle), ok_(false) {}

  // Attempts the switch to the requested scoped math mode, new_mode.
  //
  // Note that when false is returned, an appropriate error has already been
  // logged.
  bool Init(cublasMath_t new_mode) {
    cublasStatus_t ret = cublasGetMathMode(handle_, &old_mode_);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
      return ok_ = false;
    }

    ret = cublasSetMathMode(handle_, new_mode);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
      return ok_ = false;
    }
    return ok_ = true;
  }

  // Switches back to the prior math mode, if the switch operation was
  // successful in the first place.
  ~ScopedCublasMathMode() {
    if (ok_) {
      cublasStatus_t ret = cublasSetMathMode(handle_, old_mode_);
      if (ret != CUBLAS_STATUS_SUCCESS) {
        LOG(ERROR) << "failed to set former cublas math mode: "
                   << ToString(ret);
      }
    }
  }

 private:
  cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
  cublasMath_t old_mode_;  // Prior cuBLAS math mode, to be restored.
  bool ok_;                // Whether the change was successful.
};
#endif  // CUDA_VERSION >= 9000

bool CUDABlas::Init() {
  gpu::ScopedActivateExecutorContext sac{parent_};
  cublasStatus_t ret = cublasCreate(&blas_);
  if (ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
    return false;
  }

#if CUDA_VERSION >= 11000
  ret = cublasLtCreate(&blasLt_);
  if (ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
    return false;
  }
#endif  // CUDA_VERSION >= 11000

  return true;
}

CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
    : parent_(CHECK_NOTNULL(parent)),
      blas_(nullptr)
#if CUDA_VERSION >= 11000
      ,
      blasLt_(nullptr)
#endif
{
}

CUDABlas::~CUDABlas() {
  if (blas_ != nullptr) {
    gpu::ScopedActivateExecutorContext sac{parent_};
    cublasDestroy(blas_);
  }
#if CUDA_VERSION >= 11000
  if (blasLt_ != nullptr) {
    gpu::ScopedActivateExecutorContext sac{parent_};
    cublasLtDestroy(blasLt_);
  }
#endif
}

bool CUDABlas::SetStream(Stream *stream) {
  CHECK(stream != nullptr);
  CHECK(AsGpuStreamValue(stream) != nullptr);
  CHECK(blas_ != nullptr);
  gpu::ScopedActivateExecutorContext sac{parent_};
  cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
  if (ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
    return false;
  }

  return true;
}

cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
  CHECK(stream != nullptr);
  CHECK(AsGpuStreamValue(stream) != nullptr);
  gpu::ScopedActivateExecutorContext sac{parent_};
  return AsGpuStreamValue(stream);
}

namespace {

// Helper functions transforming blas arguments into cuBLAS arguments.

cublasOperation_t CUDABlasTranspose(blas::Transpose trans) {
  switch (trans) {
    case blas::Transpose::kNoTranspose:
      return CUBLAS_OP_N;
    case blas::Transpose::kTranspose:
      return CUBLAS_OP_T;
    case blas::Transpose::kConjugateTranspose:
      return CUBLAS_OP_C;
    default:
      LOG(FATAL) << "Invalid value of blas::Transpose.";
  }
}

cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
  switch (uplo) {
    case blas::UpperLower::kUpper:
      return CUBLAS_FILL_MODE_UPPER;
    case blas::UpperLower::kLower:
      return CUBLAS_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
  switch (diag) {
    case blas::Diagonal::kUnit:
      return CUBLAS_DIAG_UNIT;
    case blas::Diagonal::kNonUnit:
      return CUBLAS_DIAG_NON_UNIT;
    default:
      LOG(FATAL) << "Invalid value of blas::Diagonal.";
  }
}

cublasSideMode_t CUDABlasSide(blas::Side side) {
  switch (side) {
    case blas::Side::kLeft:
      return CUBLAS_SIDE_LEFT;
    case blas::Side::kRight:
      return CUBLAS_SIDE_RIGHT;
    default:
      LOG(FATAL) << "Invalid value of blas::Side.";
  }
}

// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
// cudaDataType_t (e.g. CUDA_R_32F).  CUDAComputationType(ty) translates from a
// blas::ComputationType to a cudaDataType_t.
//
// These are used to build the argument type and computation type args to
// cublasGemmEx.
template <typename T>
struct CUDADataType;

template <>
struct CUDADataType<Eigen::half> {
  static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
};

template <>
struct CUDADataType<std::complex<Eigen::half>> {
  static constexpr cudaDataType_t type = CUDA_C_16F;
};

template <>
struct CUDADataType<float> {
  static constexpr cudaDataType_t type = CUDA_R_32F;
};

template <>
struct CUDADataType<std::complex<float>> {
  static constexpr cudaDataType_t type = CUDA_C_32F;
};

template <>
struct CUDADataType<double> {
  static constexpr cudaDataType_t type = CUDA_R_64F;
};

template <>
struct CUDADataType<std::complex<double>> {
  static constexpr cudaDataType_t type = CUDA_C_64F;
};

template <>
struct CUDADataType<int> {
  static constexpr cudaDataType_t type = CUDA_R_32I;
};

template <>
struct CUDADataType<int8> {
  static constexpr cudaDataType_t type = CUDA_R_8I;
};

template <>
struct CUDADataType<std::complex<int8>> {
  static constexpr cudaDataType_t type = CUDA_C_8I;
};

template <>
struct CUDADataType<uint8> {
  static constexpr cudaDataType_t type = CUDA_R_8U;
};

template <>
struct CUDADataType<std::complex<uint8>> {
  static constexpr cudaDataType_t type = CUDA_C_8U;
};

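// Illustrative compile-time sanity checks, not present in the original file:
// since each specialization above is a constexpr cudaDataType_t, agreement
// with the intended cuBLAS enumerator can be verified at no runtime cost.
static_assert(CUDADataType<float>::type == CUDA_R_32F,
              "CUDADataType<float> must map to CUDA_R_32F");
static_assert(CUDADataType<std::complex<double>>::type == CUDA_C_64F,
              "CUDADataType<std::complex<double>> must map to CUDA_C_64F");
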
cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
  switch (ty) {
    case blas::ComputationType::kF16:
      return CUDA_R_16F;
    case blas::ComputationType::kF32:
      return CUDA_R_32F;
    case blas::ComputationType::kF64:
      return CUDA_R_64F;
    case blas::ComputationType::kI32:
      return CUDA_R_32I;
    case blas::ComputationType::kComplexF32:
      return CUDA_C_32F;
    case blas::ComputationType::kComplexF64:
      return CUDA_C_64F;
    case blas::ComputationType::kTF32AsF32:  // fall-through
    case blas::ComputationType::kBF16AsF32:
      // These cases are currently only supported in the blasLt routines, which
      // use CUBLASComputationType() instead.
      LOG(FATAL) << "Invalid value of blas::ComputationType.";
  }
}

#if CUDA_VERSION >= 11000
cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
  switch (ty) {
    case blas::ComputationType::kF16:
      return CUBLAS_COMPUTE_16F;
    case blas::ComputationType::kF32:  // fall-through
    case blas::ComputationType::kComplexF32:
      return CUBLAS_COMPUTE_32F;
    case blas::ComputationType::kF64:  // fall-through
    case blas::ComputationType::kComplexF64:
      return CUBLAS_COMPUTE_64F;
    case blas::ComputationType::kI32:
      return CUBLAS_COMPUTE_32I;
    case blas::ComputationType::kTF32AsF32:
      return CUBLAS_COMPUTE_32F_FAST_TF32;
    case blas::ComputationType::kBF16AsF32:
      return CUBLAS_COMPUTE_32F_FAST_16BF;
  }
}
#endif  // CUDA_VERSION >= 11000

blas::DataType GetScaleType(blas::DataType data_type,
                            blas::ComputationType compute_type) {
  bool is_complex = data_type == blas::DataType::kComplexFloat ||
                    data_type == blas::DataType::kComplexDouble;
  switch (compute_type) {
    case blas::ComputationType::kF16:
      return blas::DataType::kHalf;
    case blas::ComputationType::kF32:         // fall-through
    case blas::ComputationType::kComplexF32:  // fall-through
    case blas::ComputationType::kTF32AsF32:   // fall-through
    case blas::ComputationType::kBF16AsF32:
      return is_complex ? blas::DataType::kComplexFloat
                        : blas::DataType::kFloat;
    case blas::ComputationType::kF64:  // fall-through
    case blas::ComputationType::kComplexF64:
      return is_complex ? blas::DataType::kComplexDouble
                        : blas::DataType::kDouble;
    case blas::ComputationType::kI32:
      return blas::DataType::kInt32;
  }
}
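
// Worked example for the mapping above (illustrative, not original): a
// complex-float input with an F32-class computation type keeps a
// complex-float scale factor, i.e.
//   GetScaleType(blas::DataType::kComplexFloat,
//                blas::ComputationType::kComplexF32)
// returns blas::DataType::kComplexFloat.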

#if CUDA_VERSION >= 11000
cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
  switch (pointer_mode) {
    case blas::PointerMode::kHost:
      return CUBLASLT_POINTER_MODE_HOST;
    case blas::PointerMode::kDevice:
      return CUBLASLT_POINTER_MODE_DEVICE;
  }
}
cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
  switch (epilogue) {
    case blas::Epilogue::kDefault:
      return CUBLASLT_EPILOGUE_DEFAULT;
    case blas::Epilogue::kReLU:
      return CUBLASLT_EPILOGUE_RELU;
    case blas::Epilogue::kBias:
      return CUBLASLT_EPILOGUE_BIAS;
    case blas::Epilogue::kBiasThenReLU:
      return CUBLASLT_EPILOGUE_RELU_BIAS;
  }
}
#endif  // CUDA_VERSION >= 11000

cudaDataType_t GetCUDADataType(blas::DataType ty) {
  switch (ty) {
    case blas::DataType::kHalf:
      return CUDA_R_16F;
    case blas::DataType::kFloat:
      return CUDA_R_32F;
    case blas::DataType::kDouble:
      return CUDA_R_64F;
    case blas::DataType::kInt8:
      return CUDA_R_8I;
    case blas::DataType::kInt32:
      return CUDA_R_32I;
    case blas::DataType::kComplexFloat:
      return CUDA_C_32F;
    case blas::DataType::kComplexDouble:
      return CUDA_C_64F;
    default:
      LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType";
  }
}

int GetDataTypeSizeBytes(blas::DataType ty) {
  switch (ty) {
    case blas::DataType::kHalf:
      return 2;
    case blas::DataType::kFloat:
      return 4;
    case blas::DataType::kDouble:
      return 8;
    case blas::DataType::kInt8:
      return 1;
    case blas::DataType::kInt32:
      return 4;
    case blas::DataType::kComplexFloat:
      return 8;
    case blas::DataType::kComplexDouble:
      return 16;
    default:
      LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes";
  }
}

}  // namespace

template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                  bool pointer_mode_host, bool err_on_failure,
                                  cublasMath_t math_type, Args... args) {
  absl::MutexLock lock(&mu_);

  CHECK(blas_ != nullptr);
  if (!SetStream(stream)) {
    return false;
  }

#if CUDA_VERSION >= 9000
  ScopedCublasMathMode math_mode{blas_};
#if CUBLAS_VER_MAJOR >= 11
  if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
      tensorflow::tensor_float_32_execution_enabled()) {
#else
  if (math_type == CUBLAS_TENSOR_OP_MATH) {
#endif
    if (!math_mode.Init(math_type)) {
      return false;
    }
  }
#endif

  gpu::ScopedActivateExecutorContext sac{parent_};
  ScopedCublasPointerMode pointer_mode{blas_};
  if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
                                           : CUBLAS_POINTER_MODE_DEVICE)) {
    return false;
  }
  cublasStatus_t ret = cublas_func(blas_, args...);
  if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to run cuBLAS routine: " << ToString(ret);
  }
  return ret == CUBLAS_STATUS_SUCCESS;
}
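
// Note added for clarity (not original): in the wrappers below,
// pointer_mode_host selects whether scalar arguments such as alpha/beta and
// scalar results live in host memory (true) or device memory (false).
// DoBlasInternal is assumed to be a thin forwarder, declared in cuda_blas.h,
// that calls DoBlasInternalImpl above.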

// cublas_func may be overloaded, so we need to figure out which overload we
// really need to call based on the args. One way to do that is to wrap it in
// a lambda.
#define AS_LAMBDA(func)                                                  \
  [](auto &&... args) -> decltype(                                       \
                          func(std::forward<decltype(args)>(args)...)) { \
    return func(std::forward<decltype(args)>(args)...);                  \
  }
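
// Illustrative use of AS_LAMBDA (not from the original file): passing an
// overloaded function name directly as a template argument is ambiguous, but
// the lambda wrapper defers overload resolution to the call site:
//
//   DoBlasInternal(AS_LAMBDA(cublasSscal), stream,
//                  true /* = pointer_mode_host */, elem_count, &alpha,
//                  GpuMemoryMutable(x), incx);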
576 
577 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
578                           const DeviceMemory<float> &x, int incx,
579                           DeviceMemory<float> *result) {
580   return DoBlasInternal(cublasSasum, stream, false /* = pointer_mode_host */,
581                         elem_count, GpuMemory(x), incx,
582                         GpuMemoryMutable(result));
583 }
584 
585 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
586                           const DeviceMemory<double> &x, int incx,
587                           DeviceMemory<double> *result) {
588   return DoBlasInternal(cublasDasum, stream, false /* = pointer_mode_host */,
589                         elem_count, GpuMemory(x), incx,
590                         GpuMemoryMutable(result));
591 }
592 
593 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
594                           const DeviceMemory<std::complex<float>> &x, int incx,
595                           DeviceMemory<float> *result) {
596   return DoBlasInternal(cublasScasum, stream, false /* = pointer_mode_host */,
597                         elem_count, GpuComplex(GpuMemory(x)), incx,
598                         GpuMemoryMutable(result));
599 }
600 
601 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
602                           const DeviceMemory<std::complex<double>> &x, int incx,
603                           DeviceMemory<double> *result) {
604   return DoBlasInternal(cublasDzasum, stream, false /* = pointer_mode_host */,
605                         elem_count, GpuComplex(GpuMemory(x)), incx,
606                         GpuMemoryMutable(result));
607 }
608 
609 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
610                           const DeviceMemory<float> &x, int incx,
611                           DeviceMemory<float> *y, int incy) {
612   return DoBlasInternal(cublasSaxpy, stream, true /* = pointer_mode_host */,
613                         elem_count, &alpha, GpuMemory(x), incx,
614                         GpuMemoryMutable(y), incy);
615 }
616 
617 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
618                           const DeviceMemory<double> &x, int incx,
619                           DeviceMemory<double> *y, int incy) {
620   return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */,
621                         elem_count, &alpha, GpuMemory(x), incx,
622                         GpuMemoryMutable(y), incy);
623 }
624 
625 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
626                           std::complex<float> alpha,
627                           const DeviceMemory<std::complex<float>> &x, int incx,
628                           DeviceMemory<std::complex<float>> *y, int incy) {
629   auto cb_alpha = GpuComplexValue(alpha);
630   return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
631                         elem_count, GpuComplex(&cb_alpha),
632                         GpuComplex(GpuMemory(x)), incx,
633                         GpuComplex(GpuMemoryMutable(y)), incy);
634 }
635 
636 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
637                           std::complex<double> alpha,
638                           const DeviceMemory<std::complex<double>> &x, int incx,
639                           DeviceMemory<std::complex<double>> *y, int incy) {
640   auto cb_alpha = GpuComplexValue(alpha);
641   return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
642                         elem_count, GpuComplex(&cb_alpha),
643                         GpuComplex(GpuMemory(x)), incx,
644                         GpuComplex(GpuMemoryMutable(y)), incy);
645 }
646 
647 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
648                           const DeviceMemory<float> &x, int incx,
649                           DeviceMemory<float> *y, int incy) {
650   return DoBlasInternal(cublasScopy, stream, true /* = pointer_mode_host */,
651                         elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
652                         incy);
653 }
654 
655 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
656                           const DeviceMemory<double> &x, int incx,
657                           DeviceMemory<double> *y, int incy) {
658   return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */,
659                         elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
660                         incy);
661 }
662 
663 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
664                           const DeviceMemory<std::complex<float>> &x, int incx,
665                           DeviceMemory<std::complex<float>> *y, int incy) {
666   return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */,
667                         elem_count, GpuComplex(GpuMemory(x)), incx,
668                         GpuComplex(GpuMemoryMutable(y)), incy);
669 }
670 
671 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
672                           const DeviceMemory<std::complex<double>> &x, int incx,
673                           DeviceMemory<std::complex<double>> *y, int incy) {
674   return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */,
675                         elem_count, GpuComplex(GpuMemory(x)), incx,
676                         GpuComplex(GpuMemoryMutable(y)), incy);
677 }
678 
679 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
680                          const DeviceMemory<float> &x, int incx,
681                          const DeviceMemory<float> &y, int incy,
682                          DeviceMemory<float> *result) {
683   return DoBlasInternal(cublasSdot, stream, false /* = pointer_mode_host */,
684                         elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
685                         GpuMemoryMutable(result));
686 }
687 
688 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
689                          const DeviceMemory<double> &x, int incx,
690                          const DeviceMemory<double> &y, int incy,
691                          DeviceMemory<double> *result) {
692   return DoBlasInternal(cublasDdot, stream, false /* = pointer_mode_host */,
693                         elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
694                         GpuMemoryMutable(result));
695 }
696 
697 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
698                           const DeviceMemory<std::complex<float>> &x, int incx,
699                           const DeviceMemory<std::complex<float>> &y, int incy,
700                           DeviceMemory<std::complex<float>> *result) {
701   return DoBlasInternal(cublasCdotc, stream, false /* = pointer_mode_host */,
702                         elem_count, GpuComplex(GpuMemory(x)), incx,
703                         GpuComplex(GpuMemory(y)), incy,
704                         GpuComplex(GpuMemoryMutable(result)));
705 }
706 
707 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
708                           const DeviceMemory<std::complex<double>> &x, int incx,
709                           const DeviceMemory<std::complex<double>> &y, int incy,
710                           DeviceMemory<std::complex<double>> *result) {
711   return DoBlasInternal(cublasZdotc, stream, false /* = pointer_mode_host */,
712                         elem_count, GpuComplex(GpuMemory(x)), incx,
713                         GpuComplex(GpuMemory(y)), incy,
714                         GpuComplex(GpuMemoryMutable(result)));
715 }
716 
717 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
718                           const DeviceMemory<std::complex<float>> &x, int incx,
719                           const DeviceMemory<std::complex<float>> &y, int incy,
720                           DeviceMemory<std::complex<float>> *result) {
721   return DoBlasInternal(cublasCdotu, stream, false /* = pointer_mode_host */,
722                         elem_count, GpuComplex(GpuMemory(x)), incx,
723                         GpuComplex(GpuMemory(y)), incy,
724                         GpuComplex(GpuMemoryMutable(result)));
725 }
726 
727 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
728                           const DeviceMemory<std::complex<double>> &x, int incx,
729                           const DeviceMemory<std::complex<double>> &y, int incy,
730                           DeviceMemory<std::complex<double>> *result) {
731   return DoBlasInternal(cublasZdotu, stream, false /* = pointer_mode_host */,
732                         elem_count, GpuComplex(GpuMemory(x)), incx,
733                         GpuComplex(GpuMemory(y)), incy,
734                         GpuComplex(GpuMemoryMutable(result)));
735 }
736 
737 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
738                           const DeviceMemory<float> &x, int incx,
739                           DeviceMemory<float> *result) {
740   return DoBlasInternal(cublasSnrm2, stream, false /* = pointer_mode_host */,
741                         elem_count, GpuMemory(x), incx,
742                         GpuMemoryMutable(result));
743 }
744 
745 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
746                           const DeviceMemory<double> &x, int incx,
747                           DeviceMemory<double> *result) {
748   return DoBlasInternal(cublasDnrm2, stream, false /* = pointer_mode_host */,
749                         elem_count, GpuMemory(x), incx,
750                         GpuMemoryMutable(result));
751 }
752 
753 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
754                           const DeviceMemory<std::complex<float>> &x, int incx,
755                           DeviceMemory<float> *result) {
756   return DoBlasInternal(cublasScnrm2, stream, false /* = pointer_mode_host */,
757                         elem_count, GpuComplex(GpuMemory(x)), incx,
758                         GpuMemoryMutable(result));
759 }
760 
761 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
762                           const DeviceMemory<std::complex<double>> &x, int incx,
763                           DeviceMemory<double> *result) {
764   return DoBlasInternal(cublasDznrm2, stream, false /* = pointer_mode_host */,
765                         elem_count, GpuComplex(GpuMemory(x)), incx,
766                         GpuMemoryMutable(result));
767 }
768 
769 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
770                          DeviceMemory<float> *x, int incx,
771                          DeviceMemory<float> *y, int incy, float c, float s) {
772   return DoBlasInternal(cublasSrot, stream, true /* = pointer_mode_host */,
773                         elem_count, GpuMemoryMutable(x), incx,
774                         GpuMemoryMutable(y), incy, &c, &s);
775 }
776 
777 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
778                          DeviceMemory<double> *x, int incx,
779                          DeviceMemory<double> *y, int incy, double c,
780                          double s) {
781   return DoBlasInternal(cublasDrot, stream, true /* = pointer_mode_host */,
782                         elem_count, GpuMemoryMutable(x), incx,
783                         GpuMemoryMutable(y), incy, &c, &s);
784 }
785 
786 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
787                          DeviceMemory<std::complex<float>> *x, int incx,
788                          DeviceMemory<std::complex<float>> *y, int incy,
789                          float c, float s) {
790   return DoBlasInternal(cublasCsrot, stream, true /* = pointer_mode_host */,
791                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
792                         GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
793 }
794 
795 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
796                          DeviceMemory<std::complex<double>> *x, int incx,
797                          DeviceMemory<std::complex<double>> *y, int incy,
798                          double c, double s) {
799   return DoBlasInternal(cublasZdrot, stream, true /* = pointer_mode_host */,
800                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
801                         GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
802 }
803 
804 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
805                           DeviceMemory<float> *b, DeviceMemory<float> *c,
806                           DeviceMemory<float> *s) {
807   return DoBlasInternal(cublasSrotg, stream, false /* = pointer_mode_host */,
808                         GpuMemoryMutable(a), GpuMemoryMutable(b),
809                         GpuMemoryMutable(c), GpuMemoryMutable(s));
810 }
811 
812 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
813                           DeviceMemory<double> *b, DeviceMemory<double> *c,
814                           DeviceMemory<double> *s) {
815   return DoBlasInternal(cublasDrotg, stream, false /* = pointer_mode_host */,
816                         GpuComplex(GpuMemoryMutable(a)), GpuMemoryMutable(b),
817                         GpuMemoryMutable(c), GpuMemoryMutable(s));
818 }
819 
820 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
821                           DeviceMemory<std::complex<float>> *b,
822                           DeviceMemory<float> *c,
823                           DeviceMemory<std::complex<float>> *s) {
824   return DoBlasInternal(
825       cublasCrotg, stream, false /* = pointer_mode_host */,
826       GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
827       GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
828 }
829 
830 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
831                           DeviceMemory<std::complex<double>> *b,
832                           DeviceMemory<double> *c,
833                           DeviceMemory<std::complex<double>> *s) {
834   return DoBlasInternal(
835       cublasZrotg, stream, false /* = pointer_mode_host */,
836       GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
837       GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
838 }
839 
840 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
841                           DeviceMemory<float> *x, int incx,
842                           DeviceMemory<float> *y, int incy,
843                           const DeviceMemory<float> &param) {
844   return DoBlasInternal(cublasSrotm, stream, false /* = pointer_mode_host */,
845                         elem_count, GpuMemoryMutable(x), incx,
846                         GpuMemoryMutable(y), incy, GpuMemory(param));
847 }
848 
849 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
850                           DeviceMemory<double> *x, int incx,
851                           DeviceMemory<double> *y, int incy,
852                           const DeviceMemory<double> &param) {
853   return DoBlasInternal(cublasDrotm, stream, false /* = pointer_mode_host */,
854                         elem_count, GpuMemoryMutable(x), incx,
855                         GpuMemoryMutable(y), incy, GpuMemory(param));
856 }
857 
858 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
859                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
860                            const DeviceMemory<float> &y1,
861                            DeviceMemory<float> *param) {
862   return DoBlasInternal(cublasSrotmg, stream, false /* = pointer_mode_host */,
863                         GpuMemoryMutable(d1), GpuMemoryMutable(d2),
864                         GpuMemoryMutable(x1), GpuMemory(y1),
865                         GpuMemoryMutable(param));
866 }
867 
868 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
869                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
870                            const DeviceMemory<double> &y1,
871                            DeviceMemory<double> *param) {
872   return DoBlasInternal(cublasDrotmg, stream, false /* = pointer_mode_host */,
873                         GpuMemoryMutable(d1), GpuMemoryMutable(d2),
874                         GpuMemoryMutable(x1), GpuMemory(y1),
875                         GpuMemoryMutable(param));
876 }
877 
878 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
879                           DeviceMemory<float> *x, int incx) {
880   return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */,
881                         elem_count, &alpha, GpuMemoryMutable(x), incx);
882 }
883 
884 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
885                           DeviceMemory<double> *x, int incx) {
886   return DoBlasInternal(cublasDscal, stream, true /* = pointer_mode_host */,
887                         elem_count, &alpha, GpuMemoryMutable(x), incx);
888 }
889 
890 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
891                           DeviceMemory<std::complex<float>> *x, int incx) {
892   return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
893                         elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
894                         incx);
895 }
896 
897 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
898                           DeviceMemory<std::complex<double>> *x, int incx) {
899   return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
900                         elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
901                         incx);
902 }
903 
904 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
905                           std::complex<float> alpha,
906                           DeviceMemory<std::complex<float>> *x, int incx) {
907   auto cb_alpha = GpuComplexValue(alpha);
908   return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
909                         elem_count, GpuComplex(&cb_alpha),
910                         GpuComplex(GpuMemoryMutable(x)), incx);
911 }
912 
913 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
914                           std::complex<double> alpha,
915                           DeviceMemory<std::complex<double>> *x, int incx) {
916   auto cb_alpha = GpuComplexValue(alpha);
917   return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
918                         elem_count, GpuComplex(&cb_alpha),
919                         GpuComplex(GpuMemoryMutable(x)), incx);
920 }
921 
922 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
923                           DeviceMemory<float> *x, int incx,
924                           DeviceMemory<float> *y, int incy) {
925   return DoBlasInternal(cublasSswap, stream, true /* = pointer_mode_host */,
926                         elem_count, GpuMemoryMutable(x), incx,
927                         GpuMemoryMutable(y), incy);
928 }
929 
930 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
931                           DeviceMemory<double> *x, int incx,
932                           DeviceMemory<double> *y, int incy) {
933   return DoBlasInternal(cublasDswap, stream, true /* = pointer_mode_host */,
934                         elem_count, GpuMemoryMutable(x), incx,
935                         GpuMemoryMutable(y), incy);
936 }
937 
938 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
939                           DeviceMemory<std::complex<float>> *x, int incx,
940                           DeviceMemory<std::complex<float>> *y, int incy) {
941   return DoBlasInternal(cublasCswap, stream, true /* = pointer_mode_host */,
942                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
943                         GpuComplex(GpuMemoryMutable(y)), incy);
944 }
945 
946 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
947                           DeviceMemory<std::complex<double>> *x, int incx,
948                           DeviceMemory<std::complex<double>> *y, int incy) {
949   return DoBlasInternal(cublasZswap, stream, true /* = pointer_mode_host */,
950                         elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
951                         GpuComplex(GpuMemoryMutable(y)), incy);
952 }
953 
954 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
955                            const DeviceMemory<float> &x, int incx,
956                            DeviceMemory<int> *result) {
957   return DoBlasInternal(cublasIsamax, stream, false /* = pointer_mode_host */,
958                         elem_count, GpuMemory(x), incx,
959                         GpuMemoryMutable(result));
960 }
961 
962 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
963                            const DeviceMemory<double> &x, int incx,
964                            DeviceMemory<int> *result) {
965   return DoBlasInternal(cublasIdamax, stream, false /* = pointer_mode_host */,
966                         elem_count, GpuMemory(x), incx,
967                         GpuMemoryMutable(result));
968 }
969 
970 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
971                            const DeviceMemory<std::complex<float>> &x, int incx,
972                            DeviceMemory<int> *result) {
973   return DoBlasInternal(cublasIcamax, stream, false /* = pointer_mode_host */,
974                         elem_count, GpuComplex(GpuMemory(x)), incx,
975                         GpuMemoryMutable(result));
976 }
977 
978 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
979                            const DeviceMemory<std::complex<double>> &x,
980                            int incx, DeviceMemory<int> *result) {
981   return DoBlasInternal(cublasIzamax, stream, false /* = pointer_mode_host */,
982                         elem_count, GpuComplex(GpuMemory(x)), incx,
983                         GpuMemoryMutable(result));
984 }
985 
986 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
987                            const DeviceMemory<float> &x, int incx,
988                            DeviceMemory<int> *result) {
989   return DoBlasInternal(cublasIsamin, stream, false /* = pointer_mode_host */,
990                         elem_count, GpuComplex(GpuMemory(x)), incx,
991                         GpuMemoryMutable(result));
992 }
993 
994 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
995                            const DeviceMemory<double> &x, int incx,
996                            DeviceMemory<int> *result) {
997   return DoBlasInternal(cublasIdamin, stream, false /* = pointer_mode_host */,
998                         elem_count, GpuComplex(GpuMemory(x)), incx,
999                         GpuMemoryMutable(result));
1000 }
1001 
1002 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1003                            const DeviceMemory<std::complex<float>> &x, int incx,
1004                            DeviceMemory<int> *result) {
1005   return DoBlasInternal(cublasIcamin, stream, false /* = pointer_mode_host */,
1006                         elem_count, GpuComplex(GpuMemory(x)), incx,
1007                         GpuMemoryMutable(result));
1008 }
1009 
1010 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1011                            const DeviceMemory<std::complex<double>> &x,
1012                            int incx, DeviceMemory<int> *result) {
1013   return DoBlasInternal(cublasIzamin, stream, false /* = pointer_mode_host */,
1014                         elem_count, GpuComplex(GpuMemory(x)), incx,
1015                         GpuMemoryMutable(result));
1016 }
1017 
1018 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1019                           uint64 n, uint64 kl, uint64 ku, float alpha,
1020                           const DeviceMemory<float> &a, int lda,
1021                           const DeviceMemory<float> &x, int incx, float beta,
1022                           DeviceMemory<float> *y, int incy) {
1023   return DoBlasInternal(cublasSgbmv, stream, true /* = pointer_mode_host */,
1024                         CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1025                         GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1026                         GpuMemoryMutable(y), incy);
1027 }
1028 
1029 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1030                           uint64 n, uint64 kl, uint64 ku, double alpha,
1031                           const DeviceMemory<double> &a, int lda,
1032                           const DeviceMemory<double> &x, int incx, double beta,
1033                           DeviceMemory<double> *y, int incy) {
1034   return DoBlasInternal(cublasDgbmv, stream, true /* = pointer_mode_host */,
1035                         CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1036                         GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1037                         GpuMemoryMutable(y), incy);
1038 }
1039 
1040 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1041                           uint64 n, uint64 kl, uint64 ku,
1042                           std::complex<float> alpha,
1043                           const DeviceMemory<std::complex<float>> &a, int lda,
1044                           const DeviceMemory<std::complex<float>> &x, int incx,
1045                           std::complex<float> beta,
1046                           DeviceMemory<std::complex<float>> *y, int incy) {
1047   auto cb_alpha = GpuComplexValue(alpha);
1048   auto cb_beta = GpuComplexValue(beta);
1049   return DoBlasInternal(cublasCgbmv, stream, true /* = pointer_mode_host */,
1050                         CUDABlasTranspose(trans), m, n, kl, ku,
1051                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1052                         GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1053                         GpuComplex(GpuMemoryMutable(y)), incy);
1054 }
1055 
1056 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1057                           uint64 n, uint64 kl, uint64 ku,
1058                           std::complex<double> alpha,
1059                           const DeviceMemory<std::complex<double>> &a, int lda,
1060                           const DeviceMemory<std::complex<double>> &x, int incx,
1061                           std::complex<double> beta,
1062                           DeviceMemory<std::complex<double>> *y, int incy) {
1063   auto cb_alpha = GpuComplexValue(alpha);
1064   auto cb_beta = GpuComplexValue(beta);
1065   return DoBlasInternal(cublasZgbmv, stream, true /* = pointer_mode_host */,
1066                         CUDABlasTranspose(trans), m, n, kl, ku,
1067                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1068                         GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1069                         GpuComplex(GpuMemoryMutable(y)), incy);
1070 }
1071 
1072 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1073                           uint64 n, float alpha, const DeviceMemory<float> &a,
1074                           int lda, const DeviceMemory<float> &x, int incx,
1075                           float beta, DeviceMemory<float> *y, int incy) {
1076   return DoBlasInternal(cublasSgemv, stream, true /* = pointer_mode_host */,
1077                         CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1078                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1079                         incy);
1080 }
1081 
1082 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1083                           uint64 n, double alpha, const DeviceMemory<double> &a,
1084                           int lda, const DeviceMemory<double> &x, int incx,
1085                           double beta, DeviceMemory<double> *y, int incy) {
1086   return DoBlasInternal(cublasDgemv, stream, true /* = pointer_mode_host */,
1087                         CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1088                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1089                         incy);
1090 }
1091 
1092 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1093                           uint64 n, std::complex<float> alpha,
1094                           const DeviceMemory<std::complex<float>> &a, int lda,
1095                           const DeviceMemory<std::complex<float>> &x, int incx,
1096                           std::complex<float> beta,
1097                           DeviceMemory<std::complex<float>> *y, int incy) {
1098   auto cb_alpha = GpuComplexValue(alpha);
1099   auto cb_beta = GpuComplexValue(beta);
1100   return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
1101                         CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1102                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1103                         incx, GpuComplex(&cb_beta),
1104                         GpuComplex(GpuMemoryMutable(y)), incy);
1105 }
1106 
1107 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1108                           uint64 n, std::complex<double> alpha,
1109                           const DeviceMemory<std::complex<double>> &a, int lda,
1110                           const DeviceMemory<std::complex<double>> &x, int incx,
1111                           std::complex<double> beta,
1112                           DeviceMemory<std::complex<double>> *y, int incy) {
1113   auto cb_alpha = GpuComplexValue(alpha);
1114   auto cb_beta = GpuComplexValue(beta);
1115   return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
1116                         CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1117                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1118                         incx, GpuComplex(&cb_beta),
1119                         GpuComplex(GpuMemoryMutable(y)), incy);
1120 }
1121 
1122 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
1123                          const DeviceMemory<float> &x, int incx,
1124                          const DeviceMemory<float> &y, int incy,
1125                          DeviceMemory<float> *a, int lda) {
1126   return DoBlasInternal(cublasSger, stream, true /* = pointer_mode_host */, m,
1127                         n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1128                         GpuMemoryMutable(a), lda);
1129 }
1130 
1131 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
1132                          const DeviceMemory<double> &x, int incx,
1133                          const DeviceMemory<double> &y, int incy,
1134                          DeviceMemory<double> *a, int lda) {
1135   return DoBlasInternal(cublasDger, stream, true /* = pointer_mode_host */, m,
1136                         n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1137                         GpuMemoryMutable(a), lda);
1138 }
1139 
bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCgerc, stream, true /* = pointer_mode_host */, m,
                        n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZgerc, stream, true /* = pointer_mode_host */, m,
                        n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCgeru, stream, true /* = pointer_mode_host */, m,
                        n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZgeru, stream, true /* = pointer_mode_host */, m,
                        n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasChbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZhbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasChemv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZhemv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *a, int lda) {
  return DoBlasInternal(cublasCher, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha,
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *a, int lda) {
  return DoBlasInternal(cublasZher, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha,
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCher2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZher2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(a)), lda);
}

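// Packed Hermitian routines (HPMV/HPR/HPR2): 'ap' stores the selected
// triangle of the n x n matrix contiguously, column by column, in
// n * (n + 1) / 2 elements, so no leading dimension is passed.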
bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &ap,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasChpmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &ap,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZhpmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
                        incx, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *ap) {
  return DoBlasInternal(cublasChpr, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha,
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemoryMutable(ap)));
}

bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *ap) {
  return DoBlasInternal(cublasZhpr, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha,
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemoryMutable(ap)));
}

bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *ap) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasChpr2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(ap)));
}

bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *ap) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZhpr2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
                        GpuComplex(GpuMemory(x)), incx,
                        GpuComplex(GpuMemory(y)), incy,
                        GpuComplex(GpuMemoryMutable(ap)));
}

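// The real symmetric routines below mirror the complex Hermitian ones above:
// SBMV uses banded storage, SPMV/SPR/SPR2 use packed storage, and
// SYMV/SYR/SYR2 use conventional full storage.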
bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(cublasSsbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
                        lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
                        incy);
}

bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
                        lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
                        incy);
}

bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &ap,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(cublasSspmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &ap,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(cublasDspmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *ap) {
  return DoBlasInternal(cublasSspr, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemoryMutable(ap));
}

bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *ap) {
  return DoBlasInternal(cublasDspr, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemoryMutable(ap));
}

bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *ap) {
  return DoBlasInternal(cublasSspr2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemory(y), incy, GpuMemoryMutable(ap));
}

bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *ap) {
  return DoBlasInternal(cublasDspr2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemory(y), incy, GpuMemoryMutable(ap));
}

bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(cublasSsymv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(cublasDsymv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
                        GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *a, int lda) {
  return DoBlasInternal(cublasSsyr, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemoryMutable(a), lda);
}

bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *a, int lda) {
  return DoBlasInternal(cublasDsyr, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemoryMutable(a), lda);
}

bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *a, int lda) {
  return DoBlasInternal(cublasSsyr2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemory(y), incy, GpuMemoryMutable(a), lda);
}

bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *a, int lda) {
  return DoBlasInternal(cublasDsyr2, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
                        GpuMemory(y), incy, GpuMemoryMutable(a), lda);
}

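// Banded triangular routines (TBMV/TBSV): 'a' holds the main diagonal plus
// the k off-diagonals of the triangle in banded storage, which requires
// lda >= k + 1.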
bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(cublasStbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(cublasDtbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) {
  return DoBlasInternal(cublasCtbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) {
  return DoBlasInternal(cublasZtbmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(cublasStbsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(cublasDtbsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) {
  return DoBlasInternal(cublasCtbsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) {
  return DoBlasInternal(cublasZtbsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

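// TPMV/TPSV take the triangle in packed storage ('ap', no leading dimension);
// TRMV/TRSV further below use conventional full storage with leading
// dimension lda.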
bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) {
  return DoBlasInternal(cublasStpmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(ap),
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(cublasDtpmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(ap),
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(cublasCtpmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
                        GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(cublasZtpmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
                        GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) {
  return DoBlasInternal(cublasStpsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(ap),
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(cublasDtpsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(ap),
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(cublasCtpsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
                        GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(cublasZtpsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
                        GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(cublasStrmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(cublasDtrmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(cublasCtrmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(cublasZtrmv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(cublasStrsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(cublasDtrsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
                        GpuMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(cublasCtrsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(cublasZtrsv, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasGemm(
    Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k,
    float alpha, const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc) {
#if CUDA_VERSION >= 7050
  VLOG(1) << absl::StrFormat(
      "doing cuBLAS SGEMM: at=%d bt=%d m=%u n=%u "
      "k=%u alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than m (" << m
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than n (" << n
                   << ") (transpose case); precondition violation";
    }
  }

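  // Before CUDA 11, fp16 Tensor Core use had to be requested explicitly with
  // CUBLAS_TENSOR_OP_MATH; from CUDA 11 on, the default math mode already
  // permits Tensor Cores for fp16 inputs and CUBLAS_TENSOR_OP_MATH is
  // deprecated, so CUBLAS_DEFAULT_MATH is the right choice there.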
#if CUDA_VERSION < 11000
  cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
#else
  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
#endif

  return DoBlasInternalImpl(
      cublasSgemmEx, stream, true /* = pointer_mode_host */,
      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a),
      SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
      GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);

#else
  LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
             << "(need at least CUDA 7.5)";
  return false;
#endif
}

bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) {
  VLOG(1) << absl::StrFormat(
      "doing cuBLAS SGEMM: at=%d bt=%d m=%u n=%u "
      "k=%u alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than m (" << m
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than n (" << n
                   << ") (transpose case); precondition violation";
    }
  }

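  // CUDA 11 adds TF32 Tensor Core math for fp32 GEMMs
  // (CUBLAS_TF32_TENSOR_OP_MATH). Only compute capability 8.0 and newer
  // devices can execute it, hence the one-time log below when such a device
  // is present.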
#if CUDA_VERSION < 11000
  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
#else
  cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
  int cc_major, cc_minor;
  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
          &cc_major, &cc_minor) &&
      cc_major >= 8) {
    // TODO(reedwm): Remove or make this VLOG(1) once TensorFloat-32 is better
    // tested.
    LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
                            "multiplication. This will only be logged once.";
  }
#endif

  return DoBlasInternalImpl(
      cublasSgemm, stream, true /* = pointer_mode_host */,
      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda,
      GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) {
  return DoBlasInternal(cublasDgemm, stream, true /* = pointer_mode_host */,
                        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
                        n, k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb,
                        &beta, GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasCgemm, stream, true /* = pointer_mode_host */,
                        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
                        n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemory(b)), ldb,
                        GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
                        ldc);
}

bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZgemm, stream, true /* = pointer_mode_host */,
                        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
                        n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
                        lda, GpuComplex(GpuMemory(b)), ldb,
                        GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
                        ldc);
}

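// The *WithProfiling entry points below forward to the templated helpers
// further down, which wrap the plain BLAS call in an optional GpuTimer and
// report the elapsed time through the ProfileResult.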
bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
    int incx, float beta, DeviceMemory<float> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
    int incx, double beta, DeviceMemory<double> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
    std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
    int lda, const DeviceMemory<std::complex<float>> &x, int incx,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
    std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
    int lda, const DeviceMemory<std::complex<double>> &x, int incx,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

template <typename T>
bool CUDABlas::DoBlasGemvWithProfilingImpl(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
    const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
    const T &beta, DeviceMemory<T> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new GpuTimer(parent_));
    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
      return false;
    }
  }

  // Call blasGemv
  bool result =
      DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);

  if (timer != nullptr && result) {
    // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsGpuStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
}

template <typename T, typename ParamType>
bool CUDABlas::DoBlasGemmWithProfilingImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
    int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
    DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new GpuTimer(parent_));
    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
      return false;
    }
  }

  // Call blasGemm
  bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
                           ldb, beta, c, ldc);

  if (timer != nullptr && result) {
    // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsGpuStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
}

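// In the cublasGemmAlgo_t enum, the tensor-op variants are numbered at or
// above CUBLAS_GEMM_DEFAULT_TENSOR_OP, so a single comparison suffices to
// classify an algorithm.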
static bool UsesTensorOps(blas::AlgorithmType algo) {
#if CUDA_VERSION >= 9000
  cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
  return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
#else
  return false;
#endif
}

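// Shared implementation behind the typed DoBlasGemmWithAlgorithm overloads:
// validates device capability and math-mode constraints, optionally times the
// call with a GpuTimer, and dispatches to cublasGemmEx.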
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
    const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
    const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  // GPUs < sm_50 don't support cublasGemmEx.
  int cc_major, cc_minor;
  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
          &cc_major, &cc_minor) &&
      cc_major < 5) {
    VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major
            << cc_minor << " devices don't support explicit gemm algorithms.";
    return false;
  }

  bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
  if (algo_uses_tensor_ops) {
    if (cc_major < 7) {
      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
              << algorithm
              << " uses tensor ops, but tensor ops are not available in sm"
              << cc_major << "X devices.";
      return false;
    } else if (std::is_same<InT, float>::value) {
#if CUDA_VERSION < 11000
      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
              << algorithm
              << " uses tensor ops, but tensor ops are not available for fp32"
              << " inputs.";
      return false;
#else
      if (cc_major < 8) {
        VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
                << algorithm
                << " uses tensor ops, but tensor ops are not available in sm"
                << cc_major << "X devices for float input types.";
        return false;
      } else if (!tensorflow::tensor_float_32_execution_enabled()) {
        VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
                << algorithm
                << " uses tensor ops, but tensor ops are disabled for fp32"
                << " inputs.";
        return false;
      }
      math_type = CUBLAS_TF32_TENSOR_OP_MATH;
#endif
    } else if (std::is_same<InT, Eigen::half>::value) {
#if CUDA_VERSION < 11000
      math_type = CUBLAS_TENSOR_OP_MATH;
#endif
    } else {
      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
              << algorithm
              << " uses tensor ops, which are not supported for InT.";
      return false;
    }
  }

  // Either both 'alpha' and 'beta' need to be pointers to device memory, or
  // they both need to be host scalars.
  if (alpha.is_pointer() != beta.is_pointer()) {
    VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` "
               "and `beta` is a pointer, but the other is not.";
    return false;
  }

  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new GpuTimer(parent_));
    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because "
                 "output_profile_result was given, but we were unable to "
                 "create a GpuTimer.";
      return false;
    }
  }

  // Return false if we might be hitting a cuBLAS bug that produces the wrong
  // result. See nvbugs/2156201, b/79126339.
#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
  if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
      std::max({m, n, k}) >= 2097153 && cc_major < 7) {
    VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around a "
               "cuBLAS < 9.2 bug with m, n, or k >= 2097153. See b/79126339.";
    return false;
  }
#endif

  cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
  // we do the following compile-time check on the default value:
  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
  // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
  // essentially reinterpret_cast to __half, which is safe because Eigen::half
  // inherits from __half.
  bool result = DoBlasInternalImpl(
      AS_LAMBDA(cublasGemmEx), stream,
      /* pointer_mode_host = */ !alpha.is_pointer(), /*err_on_failure=*/false,
      math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(),
      GpuMemory(a), cuda_in_type, lda, GpuMemory(b), cuda_in_type, ldb,
      beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(),
      GpuMemoryMutable(c), CUDADataType<OutT>::type, ldc,
      CUDAComputationType(computation_type),
      static_cast<cublasGemmAlgo_t>(algorithm));

  if (timer != nullptr && result) {
    // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsGpuStream(stream))) {
      VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop "
                 "GpuTimer.";
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(algorithm);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
}

2192 bool CUDABlas::GetBlasGemmAlgorithms(
2193     std::vector<blas::AlgorithmType> *out_algorithms) {
2194   // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
2195   // were first introduced in CUDA 8.
2196   //
2197   // Note that when CUDA version and compute capability is not sufficient, we
2198   // still return the out_algorithms. Caller needs to make sure that in this
2199   // case, the returned vector is empty.
2200   *out_algorithms = {
2201     CUBLAS_GEMM_DFALT,
2202     CUBLAS_GEMM_ALGO0,
2203     CUBLAS_GEMM_ALGO1,
2204     CUBLAS_GEMM_ALGO2,
2205     CUBLAS_GEMM_ALGO3,
2206     CUBLAS_GEMM_ALGO4,
2207     CUBLAS_GEMM_ALGO5,
2208     CUBLAS_GEMM_ALGO6,
2209     CUBLAS_GEMM_ALGO7,
2210 #if CUDA_VERSION >= 9000
2211     CUBLAS_GEMM_ALGO8,
2212     CUBLAS_GEMM_ALGO9,
2213     CUBLAS_GEMM_ALGO10,
2214     CUBLAS_GEMM_ALGO11,
2215     CUBLAS_GEMM_ALGO12,
2216     CUBLAS_GEMM_ALGO13,
2217     CUBLAS_GEMM_ALGO14,
2218     CUBLAS_GEMM_ALGO15,
2219     CUBLAS_GEMM_ALGO16,
2220     CUBLAS_GEMM_ALGO17,
2221     CUBLAS_GEMM_DFALT_TENSOR_OP,
2222     CUBLAS_GEMM_ALGO0_TENSOR_OP,
2223     CUBLAS_GEMM_ALGO1_TENSOR_OP,
2224     CUBLAS_GEMM_ALGO2_TENSOR_OP,
2225     CUBLAS_GEMM_ALGO3_TENSOR_OP,
2226     CUBLAS_GEMM_ALGO4_TENSOR_OP,
2227 #endif
2228 #if CUDA_VERSION >= 9020
2229     CUBLAS_GEMM_ALGO18,
2230     CUBLAS_GEMM_ALGO19,
2231     CUBLAS_GEMM_ALGO20,
2232     CUBLAS_GEMM_ALGO21,
2233     CUBLAS_GEMM_ALGO22,
2234     CUBLAS_GEMM_ALGO23,
2235     CUBLAS_GEMM_ALGO5_TENSOR_OP,
2236     CUBLAS_GEMM_ALGO6_TENSOR_OP,
2237     CUBLAS_GEMM_ALGO7_TENSOR_OP,
2238     CUBLAS_GEMM_ALGO8_TENSOR_OP,
2239     CUBLAS_GEMM_ALGO9_TENSOR_OP,
2240     CUBLAS_GEMM_ALGO10_TENSOR_OP,
2241     CUBLAS_GEMM_ALGO11_TENSOR_OP,
2242     CUBLAS_GEMM_ALGO12_TENSOR_OP,
2243     CUBLAS_GEMM_ALGO13_TENSOR_OP,
2244     CUBLAS_GEMM_ALGO14_TENSOR_OP,
2245     CUBLAS_GEMM_ALGO15_TENSOR_OP,
2246 #endif
2247   };
2248   return true;
2249 }
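
// Hedged sketch of an autotuning loop over the algorithm list returned above.
// `blas` and the helper `RunAndTimeGemm` are assumptions for illustration,
// not part of this file:
//
//   std::vector<blas::AlgorithmType> algorithms;
//   blas->GetBlasGemmAlgorithms(&algorithms);
//   blas::ProfileResult best;
//   for (blas::AlgorithmType algo : algorithms) {
//     blas::ProfileResult profile;
//     if (RunAndTimeGemm(algo, &profile) && profile.is_valid() &&
//         (!best.is_valid() ||
//          profile.elapsed_time_in_ms() < best.elapsed_time_in_ms())) {
//       best = profile;
//     }
//   }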

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
    const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
    const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
    const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb,
    const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  if (computation_type == blas::ComputationType::kF32) {
    if (alpha.is_pointer() || beta.is_pointer()) {
      // We cannot easily convert a pointer to f16 memory to a pointer to f32
      // memory from here, so we don't support this for now.
      // TODO(akuegel): Investigate whether we can do the conversion before
      // calling DoBlasGemmWithAlgorithm.
      return false;
    }
    HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
    HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
    return DoBlasGemmWithAlgorithmImpl(
        stream, transa, transb, m, n, k, float_alpha, a, lda, b, ldb,
        float_beta, c, ldc, computation_type, algorithm, output_profile_result);
  }

  CHECK_EQ(computation_type, blas::ComputationType::kF16);
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}
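
// A minimal sketch of the host-scalar promotion performed above (illustrative
// only; the device-pointer case has no analogous f32 view of f16 memory,
// which is why it returns false):
//
//   Eigen::half h_alpha(1.0f);
//   HostOrDeviceScalar<float> f_alpha(static_cast<float>(h_alpha));
//   // f_alpha now carries the fp32 value 1.0f on the host.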

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
    int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
    int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<float>> &beta,
    DeviceMemory<std::complex<float>> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<double>> &beta,
    DeviceMemory<std::complex<double>> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

// Type mapping used by the batched GEMM path below: Eigen::half is treated as
// float (the fp16 path computes in fp32); all other types map to themselves.
template <typename T>
struct HalfAsFloat {
  typedef T type;
};

template <>
struct HalfAsFloat<Eigen::half> {
  typedef float type;
};
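
// Compile-time illustration of the HalfAsFloat mapping (a sketch; assumes
// <type_traits> is available):
//
//   static_assert(
//       std::is_same<HalfAsFloat<Eigen::half>::type, float>::value, "");
//   static_assert(std::is_same<HalfAsFloat<double>::type, double>::value, "");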

namespace {
// Pass-through for non-complex types that don't need conversion to a
// cuBLAS-specific type.
template <typename T>
T inline GpuComplexValue(T v) {
  return v;
}
}  // namespace

template <typename T, typename Scalar, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
    FuncT cublas_func, Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
    const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
    const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
    Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
  for (int i = 0; i < batch_count; ++i) {
    a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
    b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
    c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
  }

  typedef typename HalfAsFloat<typename GpuComplexT<T>::type>::type CUDA_T;

  const size_t size = batch_count * sizeof(CUDA_T *);

  // Device-side copy of pointers to matrices.
  DeviceMemory<CUDA_T *> a;
  DeviceMemory<CUDA_T *> b;
  DeviceMemory<CUDA_T *> c;

  // If temporary space is allocated for device-side copies of pointers to
  // matrices, that temporary space should not be freed until this function
  // returns. Although the values for these unique_ptrs are not set here, they
  // are declared at this scope so they will be destroyed when the function
  // returns.
  //
  // If a scratch allocator is provided, these pointers will not be used at all.
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;

  // Decide how to allocate device-side copy of pointers to matrices based on
  // whether a scratch allocator was passed.
  if (scratch_allocator != nullptr) {
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
                        scratch_allocator->AllocateBytes(size));
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
                        scratch_allocator->AllocateBytes(size));
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
                        scratch_allocator->AllocateBytes(size));
    a = DeviceMemory<CUDA_T *>(a_bytes);
    b = DeviceMemory<CUDA_T *>(b_bytes);
    c = DeviceMemory<CUDA_T *>(c_bytes);
  } else {
    SE_ASSIGN_OR_RETURN(a_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    SE_ASSIGN_OR_RETURN(b_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    SE_ASSIGN_OR_RETURN(c_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
    b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
    c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
  }

  if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
      !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
      !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
    return port::Status(port::error::INTERNAL,
                        "failed to copy memory from host to device in "
                        "CUDABlas::DoBlasGemmBatched");
  }

  cudaDataType_t data_type = CUDADataType<T>::type;

#if CUDA_VERSION >= 9010
  int cc_major, cc_minor;
  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
          &cc_major, &cc_minor) &&
      cc_major >= 5) {
    cublasMath_t math_type;
    cublasGemmAlgo_t algo;
    if (data_type == CUDA_R_16F) {
#if CUDA_VERSION < 11000
      math_type = CUBLAS_TENSOR_OP_MATH;
#else
      math_type = CUBLAS_DEFAULT_MATH;
#endif
      algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#if CUBLAS_VER_MAJOR >= 11
    } else if (data_type == CUDA_R_32F) {
      // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
      // if TensorFloat-32 is disabled.
      math_type = CUBLAS_TF32_TENSOR_OP_MATH;
      algo = tensorflow::tensor_float_32_execution_enabled()
                 ? CUBLAS_GEMM_DFALT_TENSOR_OP
                 : CUBLAS_GEMM_DFALT;
#endif
    } else {
      math_type = CUBLAS_DEFAULT_MATH;
      algo = CUBLAS_GEMM_DFALT;
    }
    cudaDataType_t compute_type =
        (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
    const void **a_void_ptrs = reinterpret_cast<const void **>(
        const_cast<const CUDA_T **>(GpuMemory(a)));
    const void **b_void_ptrs = reinterpret_cast<const void **>(
        const_cast<const CUDA_T **>(GpuMemory(b)));
    void **c_void_ptrs =
        reinterpret_cast<void **>(const_cast<CUDA_T **>(GpuMemory(c)));
    bool ok;
    ok = DoBlasInternalImpl(
        AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
        true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
        CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
        b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
        batch_count, compute_type, algo);
    if (ok) {
      return port::Status::OK();
    }
    return port::Status(port::error::INTERNAL,
                        "failed BLAS call, see log for details");
  }
#endif
  // Either CUDA_VERSION < 9.1 or SM < 5.0.
  if (data_type != CUDA_R_16F) {
    auto cb_alpha = GpuComplexValue(alpha);
    auto cb_beta = GpuComplexValue(beta);
    bool ok = DoBlasInternal(
        cublas_func, stream, true /* = pointer_mode_host */,
        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
        GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
        const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
        const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
    if (ok) {
      return port::Status::OK();
    }
    return port::Status(port::error::INTERNAL,
                        "failed BLAS call, see log for details");
  } else {
    // Fall back to a loop for fp16.
    for (int b = 0; b < batch_count; ++b) {
      const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
      const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
      DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
      bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix,
                           lda, b_matrix, ldb, beta, c_matrix, ldc);
      if (!ok) {
        return port::Status(port::error::INTERNAL,
                            "failed BLAS call, see log for details");
      }
    }
    return port::Status::OK();
  }
}
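
// Sketch of the device-side pointer-array layout that
// DoBlasGemmBatchedInternal builds above, written with plain CUDA runtime
// calls for illustration (d_a0..d_a2 are assumed device pointers; this is not
// the StreamExecutor code path):
//
//   std::vector<float *> host_ptrs = {d_a0, d_a1, d_a2};
//   float **dev_ptrs = nullptr;
//   cudaMalloc(&dev_ptrs, host_ptrs.size() * sizeof(float *));
//   cudaMemcpy(dev_ptrs, host_ptrs.data(),
//              host_ptrs.size() * sizeof(float *), cudaMemcpyHostToDevice);
//   // dev_ptrs can then serve as the Aarray argument of cublasSgemmBatched.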

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  // Note: The func passed here (cublasSgemmBatched) is not actually called,
  // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
  port::Status status = DoBlasGemmBatchedInternal(
      cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha,
    const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
    double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
    int lda,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
    int ldb, std::complex<float> beta,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
    int lda,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
    int ldb, std::complex<double> beta,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
    int64 stride_c, int batch_count) {
#if CUDA_VERSION >= 9010
  int cc_major, cc_minor;
  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
          &cc_major, &cc_minor) &&
      cc_major >= 5) {
    cublasGemmAlgo_t algo =
        (cc_major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
#if CUDA_VERSION < 11000
    cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
#else
    cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
#endif
    bool ok = DoBlasInternalImpl(
        AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
        true /* = pointer_mode_host */, true /* = err_on_failure */, math_type,
        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
        GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, ldb,
        stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c,
        batch_count, CUDA_R_32F, algo);
    if (ok) {
      return true;
    }
    LOG(ERROR) << "failed BLAS call, see log for details";
    return false;
  }
#endif
  // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
  for (int batch = 0; batch < batch_count; ++batch) {
    const auto *a_matrix =
        reinterpret_cast<const __half *>(GpuMemory(a) + batch * stride_a);
    const auto *b_matrix =
        reinterpret_cast<const __half *>(GpuMemory(b) + batch * stride_b);
    auto *c_matrix =
        reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c);
    bool ok = DoBlasInternalImpl(
        cublasSgemmEx, stream, true /* = pointer_mode_host */,
        true /* = err_on_failure */, CUBLAS_DEFAULT_MATH,
        CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
        a_matrix, SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb,
        &beta, c_matrix, SE_CUDA_DATA_HALF, ldc);
    if (!ok) {
      LOG(ERROR) << "failed BLAS call, see log for details";
      return false;
    }
  }
  return true;
}
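
// In the strided-batched layout used above, matrix i of each operand lives at
// a fixed element offset from matrix 0, so no per-batch pointer array is
// required. Illustrative addressing (offsets are element counts, not bytes):
//
//   const __half *a_i =
//       reinterpret_cast<const __half *>(GpuMemory(a) + i * stride_a);
//   const __half *b_i =
//       reinterpret_cast<const __half *>(GpuMemory(b) + i * stride_b);
//   __half *c_i =
//       reinterpret_cast<__half *>(GpuMemoryMutable(c) + i * stride_c);
//
// For densely packed column-major batches, stride_a would typically be
// lda * k for a non-transposed A (an assumption about the caller, not a
// requirement enforced here).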

bool CUDABlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
    int batch_count) {
#if CUDA_VERSION < 11000
  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
#else
  cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
#endif
  return DoBlasInternalImpl(
      cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
      true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
      CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda, stride_a,
      GpuMemory(b), ldb, stride_b, &beta, GpuMemoryMutable(c), ldc, stride_c,
      batch_count);
}

bool CUDABlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
    int batch_count) {
  return DoBlasInternal(
      cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
      GpuMemory(a), lda, stride_a, GpuMemory(b), ldb, stride_b, &beta,
      GpuMemoryMutable(c), ldc, stride_c, batch_count);
}

bool CUDABlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    int64 stride_c, int batch_count) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(
      cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
      GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
      GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
}

bool CUDABlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    int64 stride_c, int batch_count) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(
      cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
      GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
      GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
}

bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasChemm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZhemm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          float beta, DeviceMemory<std::complex<float>> *c,
                          int ldc) {
  return DoBlasInternal(cublasCherk, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          double beta, DeviceMemory<std::complex<double>> *c,
                          int ldc) {
  return DoBlasInternal(cublasZherk, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           float beta, DeviceMemory<std::complex<float>> *c,
                           int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCher2k, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, &beta,
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           double beta, DeviceMemory<std::complex<double>> *c,
                           int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZher2k, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, &beta,
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) {
  return DoBlasInternal(cublasSsymm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
                        GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) {
  return DoBlasInternal(cublasDsymm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
                        GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasCsymm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZsymm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          float beta, DeviceMemory<float> *c, int ldc) {
  return DoBlasInternal(cublasSsyrk, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, &alpha, GpuMemory(a), lda, &beta,
                        GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          double beta, DeviceMemory<double> *c, int ldc) {
  return DoBlasInternal(cublasDsyrk, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, &alpha, GpuMemory(a), lda, &beta,
                        GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasCsyrk, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
                        ldc);
}

bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZsyrk, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
                        ldc);
}

bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &b, int ldb, float beta,
                           DeviceMemory<float> *c, int ldc) {
  return DoBlasInternal(cublasSsyr2k, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
                        GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &b, int ldb, double beta,
                           DeviceMemory<double> *c, int ldc) {
  return DoBlasInternal(cublasDsyr2k, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
                        GpuMemoryMutable(c), ldc);
}

bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasCsyr2k, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) {
  auto cb_alpha = GpuComplexValue(alpha);
  auto cb_beta = GpuComplexValue(beta);
  return DoBlasInternal(cublasZsyr2k, stream, true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
                        GpuComplex(GpuMemoryMutable(c)), ldc);
}

bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) {
  return DoBlasInternal(cublasStrmm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
                        GpuMemoryMutable(b), ldb);
}

bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) {
  return DoBlasInternal(cublasDtrmm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
                        GpuMemoryMutable(b), ldb);
}

bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCtrmm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemoryMutable(b)), ldb,
                        GpuComplex(GpuMemoryMutable(b)), ldb);
}

bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZtrmm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemoryMutable(b)), ldb,
                        GpuComplex(GpuMemoryMutable(b)), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) {
  return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) {
  return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemoryMutable(b)), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemoryMutable(b)), ldb);
}

// We only use cublasLt from CUDA 11.0 onward.
#if CUDA_VERSION >= 11000

namespace {

template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
                                    cublasLtMatrixLayoutAttribute_t attr,
                                    const T &value) {
  cublasStatus_t status =
      cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(
        port::error::INTERNAL,
        absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
                     ", value=", value, ") failed: ", ToString(status)));
  }
  return port::Status::OK();
}

template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle,
                                    cublasLtMatmulAlgoConfigAttributes_t attr,
                                    const T &value) {
  cublasStatus_t status =
      cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(
        port::error::INTERNAL,
        absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
                     ", value=", value, ") failed: ", ToString(status)));
  }
  return port::Status::OK();
}

template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
                                    cublasLtMatmulPreferenceAttributes_t attr,
                                    const T &value) {
  cublasStatus_t status =
      cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(
        port::error::INTERNAL,
        absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
                     ", value=", value, ") failed: ", ToString(status)));
  }
  return port::Status::OK();
}

template <typename T>
inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle,
                            cublasLtMatmulAlgoConfigAttributes_t attr,
                            T *value) {
  auto mutable_handle = const_cast<cublasLtMatmulAlgo_t *>(handle);
  size_t bytes_written = 0;
  return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
                                              sizeof(T), &bytes_written) ==
             CUBLAS_STATUS_SUCCESS &&
         bytes_written == sizeof(T);
}
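
// Illustrative use of GetCublasLtAttr to recover a config field from an
// opaque algorithm handle (a sketch; `algo` is assumed to be a valid
// cublasLtMatmulAlgo_t initialized elsewhere):
//
//   int algo_id = -1;
//   if (GetCublasLtAttr(&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id)) {
//     VLOG(1) << "cublasLt algorithm id: " << algo_id;
//   }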

template <typename T>
inline const T &ValueForStrCat(const T &value) {
  return value;
}
template <typename T>
inline absl::Hex ValueForStrCat(T *ptr) {
  return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
}

template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
                                    cublasLtMatmulDescAttributes_t attr,
                                    const T &value) {
  cublasStatus_t status =
      cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(
        port::error::INTERNAL,
        absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
                     ValueForStrCat(value), ") failed: ", ToString(status)));
  }
  return port::Status::OK();
}

struct MatmulDescDestroyer {
  void operator()(cublasLtMatmulDesc_t matmul_desc) const {
    cublasLtMatmulDescDestroy(matmul_desc);
  }
};
struct LayoutDestroyer {
  void operator()(cublasLtMatrixLayout_t layout) const {
    cublasLtMatrixLayoutDestroy(layout);
  }
};
struct MatmulPreferenceDestroyer {
  void operator()(cublasLtMatmulPreference_t matmul_pref) const {
    cublasLtMatmulPreferenceDestroy(matmul_pref);
  }
};
using UniqueOpDesc =
    std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
                    MatmulDescDestroyer>;
using UniqueLayoutDesc =
    std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
                    LayoutDestroyer>;
using UniqueMatmulPreference =
    std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
                    MatmulPreferenceDestroyer>;
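
// The Unique* aliases above give cublasLt handles ordinary RAII semantics.
// Minimal sketch of the intended usage (illustrative only):
//
//   cublasLtMatmulDesc_t raw_desc;
//   if (cublasLtMatmulDescCreate(&raw_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F) ==
//       CUBLAS_STATUS_SUCCESS) {
//     UniqueOpDesc desc(raw_desc);  // Destroyed automatically on scope exit.
//   }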

port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
    blas::ComputationType computation_type, blas::DataType scale_type,
    blas::PointerMode pointer_mode, blas::Epilogue epilogue,
    blas::Transpose transa, blas::Transpose transb) {
  cublasLtMatmulDesc_t desc;
  cublasComputeType_t cublas_compute_type =
      CUBLASComputationType(computation_type);
  cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
  cublasStatus_t status =
      cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(
        port::error::INTERNAL,
        absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
                     computation_type, ") failed: ", ToString(status)));
  }
  UniqueOpDesc unique_desc(desc);
  SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
                                     CUBLASPointerMode(pointer_mode)));
  SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                     CUBLASEpilogue(epilogue)));
  SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
                                     CUDABlasTranspose(transa)));
  SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
                                     CUDABlasTranspose(transb)));
  return unique_desc;
}

port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
    blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride,
    int batch_count) {
  cublasLtMatrixLayout_t desc;
  cublasStatus_t status = cublasLtMatrixLayoutCreate(
      &desc, GetCUDADataType(data_type), rows, cols, ld);
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(
        port::error::INTERNAL,
        absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
  }
  UniqueLayoutDesc unique_desc(desc);
  SE_RETURN_IF_ERROR(
      SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
  SE_RETURN_IF_ERROR(SetCublasLtAttr(
      desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
  return unique_desc;
}
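
// Example of how the (rows, cols) arguments relate to the GEMM shape: for
// C(m x n) = A(m x k) * B(k x n) with no transposes and column-major data,
// the layouts would be created roughly as follows (an illustration; the real
// transpose-aware selection lives in CUDABlasLtMatmulPlan::init below):
//
//   CreateCublasLtLayoutDesc(ab_type, /*rows=*/m, /*cols=*/k, lda, stride_a,
//                            batch_count);  // A
//   CreateCublasLtLayoutDesc(ab_type, /*rows=*/k, /*cols=*/n, ldb, stride_b,
//                            batch_count);  // B
//   CreateCublasLtLayoutDesc(c_type, /*rows=*/m, /*cols=*/n, ldc, stride_c,
//                            batch_count);  // C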

// Helper function to allocate workspace.
port::Status AllocateWorkspace(void **workspace,
                               ScratchAllocator *scratch_allocator,
                               size_t num_bytes) {
  SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
                      scratch_allocator->AllocateBytes(num_bytes));
  *workspace = static_cast<void *>(GpuMemoryMutable(&workspace_bytes));
  return port::Status::OK();
}
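
// Illustrative call pattern for the helper above (a sketch; assumes a live
// ScratchAllocator named `allocator` inside a function returning
// port::Status):
//
//   void *workspace = nullptr;
//   SE_RETURN_IF_ERROR(
//       AllocateWorkspace(&workspace, allocator, /*num_bytes=*/1 << 20));
//   // `workspace` now points at 1 MiB of device scratch owned by the
//   // allocator's lifetime.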
3234 
3235 template <typename T>
3236 blas::ComputationType ToComputationType();
3237 template <>
3238 blas::ComputationType ToComputationType<Eigen::half>() {
3239   return blas::ComputationType::kF16;
3240 }
3241 template <>
3242 blas::ComputationType ToComputationType<float>() {
3243   return blas::ComputationType::kF32;
3244 }
3245 template <>
3246 blas::ComputationType ToComputationType<double>() {
3247   return blas::ComputationType::kF64;
3248 }
3249 template <>
3250 blas::ComputationType ToComputationType<std::complex<float>>() {
3251   return blas::ComputationType::kComplexF32;
3252 }
3253 template <>
3254 blas::ComputationType ToComputationType<std::complex<double>>() {
3255   return blas::ComputationType::kComplexF64;
3256 }
3257 
3258 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
3259  public:
3260   port::Status init(const blas::BlasLtMatmulPlanParams &p) {
3261     params_ = p;
3262     scale_type_ = GetScaleType(p.c_type, p.computation_type);
3263     SE_ASSIGN_OR_RETURN(
3264         op_desc_,
3265         CreateCublasLtOperationDesc(
3266             p.computation_type, GetScaleType(p.c_type, p.computation_type),
3267             p.pointer_mode, p.epilogue, p.transa, p.transb));
3268     uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
3269     uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
3270     uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
3271     uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
3272     SE_ASSIGN_OR_RETURN(
3273         a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
3274                                           p.stride_a, capped_batch_count()));
3275     SE_ASSIGN_OR_RETURN(
3276         b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
3277                                           p.stride_b, capped_batch_count()));
3278     SE_ASSIGN_OR_RETURN(
3279         c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3280                                           capped_batch_count()));
3281     SE_ASSIGN_OR_RETURN(
3282         d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3283                                           capped_batch_count()));
3284     remainder_batch_count_ =
3285         p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
    if (remainder_batch_count_) {
      SE_ASSIGN_OR_RETURN(
          a_remainder_desc_,
          CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
                                   remainder_batch_count_));
      SE_ASSIGN_OR_RETURN(
          b_remainder_desc_,
          CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
                                   remainder_batch_count_));
      SE_ASSIGN_OR_RETURN(
          c_remainder_desc_,
          CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
                                   remainder_batch_count_));
      SE_ASSIGN_OR_RETURN(
          d_remainder_desc_,
          CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
                                   remainder_batch_count_));
    }
    return port::Status::OK();
  }

  cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
  cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
  cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
  cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
  cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
  cublasLtMatrixLayout_t a_remainder_desc() const {
    return a_remainder_desc_.get();
  }
  cublasLtMatrixLayout_t b_remainder_desc() const {
    return b_remainder_desc_.get();
  }
  cublasLtMatrixLayout_t c_remainder_desc() const {
    return c_remainder_desc_.get();
  }
  cublasLtMatrixLayout_t d_remainder_desc() const {
    return d_remainder_desc_.get();
  }

  const blas::BlasLtMatmulPlanParams &params() const { return params_; }
  blas::DataType scale_type() const { return scale_type_; }
  blas::DataType ab_type() const override { return params_.ab_type; }
  blas::DataType c_type() const override { return params_.c_type; }
  int capped_batch_count() const {
    return std::min(params_.batch_count, kMaxBatchCount);
  }
  int remainder_batch_count() const { return remainder_batch_count_; }

  // Note: Must be const to satisfy API. This is always called before the plan
  // is executed, so the state change is not observed in subsequent executions.
  bool SetBiasPointer(const void *bias) const;

 private:
  // In some cases cublasLt does not support large batch sizes, so we need to
  // split up such cases into multiple calls.
  // TODO(reedwm): Making this static or constexpr causes a link error with gcc
  // in debug mode for unknown reasons. Investigate why.
  const int kMaxBatchCount = 65535;
  blas::BlasLtMatmulPlanParams params_;
  blas::DataType scale_type_;
  UniqueOpDesc op_desc_;
  // These have batch count set to capped_batch_count().
  UniqueLayoutDesc a_desc_;
  UniqueLayoutDesc b_desc_;
  UniqueLayoutDesc c_desc_;
  UniqueLayoutDesc d_desc_;
  int remainder_batch_count_;
  // These have batch count set to remainder_batch_count_, and are only created
  // if params_.batch_count > kMaxBatchCount.
  UniqueLayoutDesc a_remainder_desc_;
  UniqueLayoutDesc b_remainder_desc_;
  UniqueLayoutDesc c_remainder_desc_;
  UniqueLayoutDesc d_remainder_desc_;
};

bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
  return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                         bias)
      .ok();
}

class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
 public:
  CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
                            cublasLtMatmulAlgo_t algo, size_t workspace_size)
      : index_(index), algo_(algo), workspace_size_(workspace_size) {}

  blas::AlgorithmType index() const override { return index_; }

  size_t workspace_size() const override { return workspace_size_; }

  const cublasLtMatmulAlgo_t *algo() const { return &algo_; }

  int algo_id() const {
    int id;
    GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
    return id;
  }

 private:
  blas::AlgorithmType index_;
  cublasLtMatmulAlgo_t algo_;
  size_t workspace_size_;
};

port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
    const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) {
  cublasLtMatmulPreference_t preference;
  cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::Status(port::error::INTERNAL,
                        absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
                                     ToString(status)));
  }
  UniqueMatmulPreference unique_preference(preference);
  SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
                                     CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                     max_workspace_bytes));

  const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
  if (cuda_plan.params().batch_count == 0) {
    return unique_preference;
  }
  // This is a workaround for a known issue in cuBlasLt where the heuristic may
  // in rare cases select an algo that does not support the specified stride.
  // Specifying the alignment requirements manually like this avoids the issue.
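  // The expression (stride & -stride) below isolates the lowest set bit of
  // the stride, i.e. the largest power of two dividing it. Worked example
  // (hypothetical values): a stride of 24 elements gives 24 & -24 = 8, which
  // with 4-byte F32 elements guarantees 32-byte alignment between
  // consecutive matrices in the batch.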
  auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
    return (stride & -stride) * GetDataTypeSizeBytes(dtype);
  };
  if (cuda_plan.params().stride_a) {
    SE_RETURN_IF_ERROR(SetCublasLtAttr(
        preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
        (uint32)get_alignment_bytes(cuda_plan.params().stride_a,
                                    cuda_plan.params().ab_type)));
  }
  if (cuda_plan.params().stride_b) {
    SE_RETURN_IF_ERROR(SetCublasLtAttr(
        preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
        (uint32)get_alignment_bytes(cuda_plan.params().stride_b,
                                    cuda_plan.params().ab_type)));
  }
  if (cuda_plan.params().stride_c) {
    SE_RETURN_IF_ERROR(SetCublasLtAttr(
        preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
        (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
                                    cuda_plan.params().c_type)));
  }
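  // The output matrix D shares C's layout in this plan (its descriptor is
  // created with stride_c and c_type), so its alignment is intentionally
  // derived from the same parameters.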
  if (cuda_plan.params().stride_c) {
    SE_RETURN_IF_ERROR(SetCublasLtAttr(
        preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
        (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
                                    cuda_plan.params().c_type)));
  }
  return unique_preference;
}

}  // namespace

#endif  // CUDA_VERSION >= 11000

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
#if CUDA_VERSION >= 11000
  auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
  SE_RETURN_IF_ERROR(cuda_plan->init(p));
  return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
      std::move(cuda_plan));
#else
  return port::Status(
      port::error::UNIMPLEMENTED,
      "CreateBlasLtMatmulPlan is not supported with this version of CUDA");
#endif
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
                                            size_t max_workspace_size,
                                            int max_algorithm_count,
                                            bool for_remainder_batch) {
#if CUDA_VERSION >= 11000
  SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
                      CreateCublasLtMatmulPreference(plan, max_workspace_size));

  std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
  {
    absl::MutexLock lock(&mu_);

    CHECK(blasLt_ != nullptr);

    gpu::ScopedActivateExecutorContext sac{parent_};

    int found_algorithm_count = 0;
    const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
    const auto &a_desc =
        for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
    const auto &b_desc =
        for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
    const auto &c_desc =
        for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
    const auto &d_desc =
        for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
    cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
        blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
        preference.get(), max_algorithm_count, results.data(),
        &found_algorithm_count);
    if (status != CUBLAS_STATUS_SUCCESS) {
      return port::Status(
          port::error::INTERNAL,
          absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
                       ToString(status)));
    }
    results.resize(found_algorithm_count);
  }

  std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
  out_algorithms.reserve(results.size());
  for (size_t i = 0; i < results.size(); ++i) {
    const auto &result = results[i];
    if (result.state != CUBLAS_STATUS_SUCCESS) continue;  // Skip failed algos
    out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
        i, result.algo, result.workspaceSize));
  }
  return out_algorithms;
#else  // if CUDA_VERSION < 11000
  return port::Status(
      port::error::UNIMPLEMENTED,
      "GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
#endif
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                    size_t max_workspace_size,
                                    int max_algorithm_count) {
  return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
                                           max_algorithm_count);
}
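// Illustrative call sequence (hypothetical caller, not part of this file),
// given some blas::BlasSupport *blas and plan params:
//   SE_ASSIGN_OR_RETURN(auto plan, blas->CreateBlasLtMatmulPlan(params));
//   SE_ASSIGN_OR_RETURN(auto algorithms,
//                       blas->GetBlasLtMatmulAlgorithms(
//                           plan.get(), /*max_workspace_size=*/1 << 22,
//                           /*max_algorithm_count=*/4));
// Any algorithm returned here can then be passed to DoBlasLtMatmul below.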

#if CUDA_VERSION >= 11000
bool CUDABlas::DoBlasLtMatmulInternal(
    Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
    const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
    DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
    DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) {
  const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
  const auto &cuda_algo =
      *static_cast<const CUDABlasLtMatmulAlgorithm *>(algorithm);

  if (alpha.data_type() != cuda_plan.scale_type() ||
      beta.data_type() != cuda_plan.scale_type()) {
    VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do "
               "not match plan: expected "
            << cuda_plan.scale_type() << ", got alpha=" << alpha.data_type()
            << " beta=" << beta.data_type();
    return false;
  }
  if (alpha.is_pointer() != beta.is_pointer()) {
    VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
               "and `beta` is a pointer, but the other is not.";
    return false;
  }
  bool is_pointer_mode_host = !alpha.is_pointer();
  if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
      is_pointer_mode_host) {
    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
               "pointer_mode for the given alpha/beta.";
    return false;
  }
  if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
       cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
      (bias != nullptr)) {
    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
               "epilogue for the given bias pointer.";
    return false;
  }
  const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
                                             : alpha.opaque_value();
  const void *beta_ptr =
      beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value();

  void *workspace = nullptr;
  if (cuda_algo.workspace_size()) {
    port::Status allocation_status = AllocateWorkspace(
        &workspace, scratch_allocator, cuda_algo.workspace_size());
    if (!allocation_status.ok()) {
      if (err_on_failure || VLOG_IS_ON(3)) {
        LOG(ERROR)
            << "Failed to allocate workspace for cublasLtMatmul algo with id: "
            << cuda_algo.algo_id() << " requiring "
            << cuda_algo.workspace_size() << " bytes of workspace";
      }
      return false;
    }
  }

  // This is only used when batch_count > kMaxBatchCount.
  std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
  if (cuda_plan.remainder_batch_count()) {
    // There is no easy way to get the user-specified max workspace size here,
    // so we just allow a very small amount and don't worry too much about
    // performance because this is only used in rare cases. The same reasoning
    // applies to selection of the algorithm.
    size_t max_workspace_size = 4 * 1024 * 1024;  // 4 MiB
    auto status_or_algorithms =
        GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
                                          /* max_algorithm_count = */ 1,
                                          /* for_remainder_batch = */ true);
    if (!status_or_algorithms.ok()) {
      if (err_on_failure || VLOG_IS_ON(3)) {
        LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
      }
      return false;
    }
    auto algorithms = status_or_algorithms.ConsumeValueOrDie();
    unique_remainder_algo = std::move(algorithms.front());
  }

  cudaStream_t cuda_stream = CUDAStream(stream);

  absl::MutexLock lock(&mu_);

  if (bias != nullptr) {
    if (!cuda_plan.SetBiasPointer(bias.opaque())) {
      VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
                 "pointer failed.";
      return false;
    }
  }

  CHECK(blasLt_ != nullptr);

  gpu::ScopedActivateExecutorContext sac{parent_};

  // Plan execution is broken down into repeated calls with capped_batch_count,
  // followed by a final call with remainder_batch_count.
  // Cases where batch_count <= kMaxBatchCount require only a single call (a
  // single loop iteration and no remainder).
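  // Worked example (hypothetical sizes): batch_count = 131071 with
  // kMaxBatchCount = 65535 runs the loop twice (batches [0, 65535) and
  // [65535, 131070)) and then issues one remainder call for the final
  // 131071 % 65535 = 1 batch.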
  int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
  int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
  const char *a_ptr = static_cast<const char *>(a.opaque());
  const char *b_ptr = static_cast<const char *>(b.opaque());
  const char *c_ptr = static_cast<const char *>(c.opaque());
  char *d_ptr = static_cast<char *>(d.opaque());
  int capped_batch_count = cuda_plan.capped_batch_count();
  for (int batch = 0;
       batch + capped_batch_count <= cuda_plan.params().batch_count;
       batch += capped_batch_count) {
    cublasStatus_t ret = cublasLtMatmul(
        blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
        b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
        cuda_plan.d_desc(), cuda_algo.algo(), workspace,
        cuda_algo.workspace_size(), cuda_stream);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      if (err_on_failure || VLOG_IS_ON(3)) {
        LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
      }
      return false;
    }
    a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
    b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
    c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
    d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
  }
  // This is only used when batch_count > kMaxBatchCount.
  if (cuda_plan.remainder_batch_count()) {
    const auto &remainder_algo =
        *static_cast<const CUDABlasLtMatmulAlgorithm *>(
            unique_remainder_algo.get());
    if (remainder_algo.workspace_size()) {
      port::Status allocation_status = AllocateWorkspace(
          &workspace, scratch_allocator, remainder_algo.workspace_size());
      if (!allocation_status.ok()) {
        if (err_on_failure || VLOG_IS_ON(3)) {
          LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
                        "with id: "
                     << remainder_algo.algo_id() << " requiring "
                     << remainder_algo.workspace_size()
                     << " bytes of workspace";
        }
        return false;
      }
    }
    cublasStatus_t ret = cublasLtMatmul(
        blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
        cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
        beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
        cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
        remainder_algo.workspace_size(), cuda_stream);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      if (err_on_failure || VLOG_IS_ON(3)) {
        LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
                   << ToString(ret);
      }
      return false;
    }
  }
  return true;
}
#endif  // CUDA_VERSION >= 11000

bool CUDABlas::DoBlasLtMatmul(
    Stream *stream, const blas::IBlasLtMatmulPlan *plan,
    const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
    DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
    DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
    blas::ProfileResult *output_profile_result) {
#if CUDA_VERSION >= 11000
  const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
  HostOrDeviceScalar<void> alpha_cast = alpha;
  HostOrDeviceScalar<void> beta_cast = beta;
  if (cuda_plan.c_type() == blas::DataType::kHalf &&
      cuda_plan.scale_type() == blas::DataType::kFloat) {
    // The given alpha and beta types are F16 (they always match c), but F32*
    // computation type requires that they be F32, so we must cast them.
    if (alpha.is_pointer() || beta.is_pointer()) {
      // We cannot easily convert a pointer to f16 memory to a pointer to f32
      // memory from here, so we don't support this for now.
      return false;
    }
    alpha_cast = HostOrDeviceScalar<void>(
        static_cast<float>(alpha.value<Eigen::half>()));
    beta_cast =
        HostOrDeviceScalar<void>(static_cast<float>(beta.value<Eigen::half>()));
  }
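  // Example (hypothetical values): with c_type = kHalf and an F32 computation
  // type, scale_type() is kFloat, so a caller-provided Eigen::half alpha of
  // 1.0 is widened above to the float 1.0f that cublasLtMatmul expects.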

  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
  if (output_profile_result) {
    timer.reset(new GpuTimer(parent_));
    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
      return false;
    }
  }

  bool err_on_failure = timer != nullptr;
  bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast,
                                       a, b, beta_cast, c, c, scratch_allocator,
                                       algorithm, bias);

  if (timer && result) {
    // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsGpuStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(algorithm->index());
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
#else  // if CUDA_VERSION < 11000
  return false;
#endif
}

port::Status CUDABlas::GetVersion(std::string *version) {
  absl::MutexLock lock(&mu_);

  int v;
  auto status = cublasGetVersion(blas_, &v);
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::InternalError(ToString(status));
  }
  *version = std::to_string(v);
  return port::Status::OK();
}

}  // namespace gpu

void initialize_cublas() {
  port::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
          cuda::kCudaPlatformId, gpu::kCuBlasPlugin, "cuBLAS",
          [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * {
            gpu::GpuExecutor *cuda_executor =
                dynamic_cast<gpu::GpuExecutor *>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR)
                  << "Attempting to initialize an instance of the cuBLAS "
                  << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            gpu::CUDABlas *blas = new gpu::CUDABlas(cuda_executor);
            if (!blas->Init()) {
              // Note: Init() will log a more specific error.
              delete blas;
              return nullptr;
            }
            return blas;
          });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuBLAS factory: "
               << status.error_message();
  }

  PluginRegistry::Instance()->SetDefaultFactory(
      cuda::kCudaPlatformId, PluginKind::kBlas, gpu::kCuBlasPlugin);
}

}  // namespace stream_executor

REGISTER_MODULE_INITIALIZER(register_cublas,
                            { stream_executor::initialize_cublas(); });