1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "third_party/gpus/cuda/include/cublasLt.h"
17 #include "third_party/gpus/cuda/include/cublas_v2.h"
18 #include "third_party/gpus/cuda/include/cuda.h"
19
20 #define SE_CUDA_DATA_HALF CUDA_R_16F
21
22 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
23
24 // Both Eigen Half.h and CUDA cuda_fp16.h provide similar typedef for __half. As
25 // such, there are two ways to get the typedef for __half:
26 //
27 // (1) Includes cuda_fp16.h and defines EIGEN_HAS_CUDA_FP16.
28 // (2) Neither includes cuda_fp16.h nor defines EIGEN_HAS_CUDA_FP16.
29 //
30 // Due to issue b/73793421, when the first approach is used and NVCC is used to
31 // compile this file, NVCC will complain duplicated definition for
32 // EIGEN_HAS_CUDA_FP16. On the other hand, when the second approach is used and
33 // clang is used to compile this file, clang will not understand __half
34 // due to missing the definition and macro EIGEN_HAS_CUDA_FP16.
35 //
36 // Because this file may be compiled with CLANG but will never be compiled with
37 // NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we have
38 // to use the second approach because the data member in the __half defined
39 // by CUDA > 9.0 is `__x` while Eigen expects it to be `x`.
40 //
41 // TODO(b/73793421): Remove the following code block to switch to the second
42 // approach when the issue is fixed.
43 #if CUDA_VERSION < 9000
44 #include "third_party/gpus/cuda/include/cuda_fp16.h"
45 #define EIGEN_HAS_CUDA_FP16
46 #endif
47
48 #include <complex>
49
50 #include "absl/strings/str_cat.h"
51 #include "absl/strings/str_format.h"
52 #include "third_party/eigen3/Eigen/Core"
53 #include "tensorflow/core/platform/tensor_float_32_utils.h"
54 #include "tensorflow/core/util/env_var.h"
55 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
56 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
57 #include "tensorflow/stream_executor/cuda/cuda_helpers.h"
58 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
59 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
60 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
61 #include "tensorflow/stream_executor/device_memory.h"
62 #include "tensorflow/stream_executor/lib/env.h"
63 #include "tensorflow/stream_executor/lib/initialize.h"
64 #include "tensorflow/stream_executor/lib/status.h"
65 #include "tensorflow/stream_executor/lib/status_macros.h"
66 #include "tensorflow/stream_executor/platform/logging.h"
67 #include "tensorflow/stream_executor/platform/port.h"
68 #include "tensorflow/stream_executor/plugin_registry.h"
69 #include "tensorflow/stream_executor/scratch_allocator.h"
70 #include "tensorflow/stream_executor/stream_executor.h"
71
72 namespace stream_executor {
73 namespace gpu {
74
75 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
76
ToString(cublasStatus_t status)77 static std::string ToString(cublasStatus_t status) {
78 switch (status) {
79 case CUBLAS_STATUS_SUCCESS:
80 return "CUBLAS_STATUS_SUCCESS";
81 case CUBLAS_STATUS_NOT_INITIALIZED:
82 return "CUBLAS_STATUS_NOT_INITIALIZED";
83 case CUBLAS_STATUS_ALLOC_FAILED:
84 return "CUBLAS_STATUS_ALLOC_FAILED";
85 case CUBLAS_STATUS_INVALID_VALUE:
86 return "CUBLAS_STATUS_INVALID_VALUE";
87 case CUBLAS_STATUS_ARCH_MISMATCH:
88 return "CUBLAS_STATUS_ARCH_MISMATCH";
89 case CUBLAS_STATUS_MAPPING_ERROR:
90 return "CUBLAS_STATUS_MAPPING_ERROR";
91 case CUBLAS_STATUS_EXECUTION_FAILED:
92 return "CUBLAS_STATUS_EXECUTION_FAILED";
93 case CUBLAS_STATUS_INTERNAL_ERROR:
94 return "CUBLAS_STATUS_INTERNAL_ERROR";
95 #if CUDA_VERSION >= 8000
96 case CUBLAS_STATUS_NOT_SUPPORTED:
97 return "CUBLAS_STATUS_NOT_SUPPORTED";
98 case CUBLAS_STATUS_LICENSE_ERROR:
99 return "CUBLAS_STATUS_LICENSE_ERROR";
100 #endif
101 default:
102 return absl::StrCat("<invalid cublas status: ", status, ">");
103 }
104 }
105
106 // cuBLAS has interfaces that permit pointers to be passed from either the host
107 // memory space or the device memory space; however, you must instruct it as to
108 // which address space those pointers are in with cublasSetPointerMode.
109 //
110 // This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
111 // you are about to perform in a given scope.
112 //
113 // The prior cuBLAS pointer mode is retained and restored when this object goes
114 // out of scope.
115 class ScopedCublasPointerMode {
116 public:
117 // Note that, because the setting of the cublas pointer mode is fallible,
118 // construction of this scoped datatype must be paired with a call to
119 // Init().
120 //
121 // Parameters:
122 // handle: The cublas library handle to act upon in setting the pointer mode.
ScopedCublasPointerMode(cublasHandle_t handle)123 explicit ScopedCublasPointerMode(cublasHandle_t handle)
124 : handle_(handle), ok_(false) {}
125
126 // Attempts the switch to the requested scoped pointer mode, new_mode.
127 //
128 // Note that when false is returned, an appropriate error has already been
129 // logged.
Init(cublasPointerMode_t new_mode)130 bool Init(cublasPointerMode_t new_mode) {
131 cublasStatus_t ret = cublasGetPointerMode(handle_, &old_mode_);
132 if (ret != CUBLAS_STATUS_SUCCESS) {
133 LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
134 return ok_ = false;
135 }
136
137 ret = cublasSetPointerMode(handle_, new_mode);
138 if (ret != CUBLAS_STATUS_SUCCESS) {
139 LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
140 return ok_ = false;
141 }
142
143 return ok_ = true;
144 }
145
146 // Switches back to the prior pointer mode, if the switch operation was
147 // successful in the first place.
~ScopedCublasPointerMode()148 ~ScopedCublasPointerMode() {
149 if (ok_) {
150 cublasStatus_t ret = cublasSetPointerMode(handle_, old_mode_);
151 if (ret != CUBLAS_STATUS_SUCCESS) {
152 LOG(ERROR) << "failed to set former cublas pointer mode: "
153 << ToString(ret);
154 }
155 }
156 }
157
158 private:
159 cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
160 cublasPointerMode_t old_mode_; // Prior cuBLAS pointer mode, to be restored.
161 bool ok_; // Whether the change was successful.
162 };
163
164 #if CUDA_VERSION >= 9000
165 // cuBLAS has interfaces that permit computations to use the Volta hardware.
166 // This must be enabled via the cublasGet/SetMathMode APIs.
167 //
168 // This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
169 // you are about to perform in a given scope.
170 //
171 // The prior cuBLAS math mode is retained and restored when this object goes
172 // out of scope.
173 class ScopedCublasMathMode {
174 public:
175 // Note that, because the setting of the cublas math mode is fallible,
176 // construction of this scoped datatype must be paired with a call to
177 // Init().
178 //
179 // Parameters:
180 // handle: The cublas library handle to act upon in setting the math mode.
ScopedCublasMathMode(cublasHandle_t handle)181 explicit ScopedCublasMathMode(cublasHandle_t handle)
182 : handle_(handle), ok_(false) {}
183
184 // Attempts the switch to the requested scoped math mode, new_mode.
185 //
186 // Note that when false is returned, an appropriate error has already been
187 // logged.
Init(cublasMath_t new_mode)188 bool Init(cublasMath_t new_mode) {
189 cublasStatus_t ret = cublasGetMathMode(handle_, &old_mode_);
190 if (ret != CUBLAS_STATUS_SUCCESS) {
191 LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
192 return ok_ = false;
193 }
194
195 ret = cublasSetMathMode(handle_, new_mode);
196 if (ret != CUBLAS_STATUS_SUCCESS) {
197 LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
198 return ok_ = false;
199 }
200 return ok_ = true;
201 }
202
203 // Switches back to the prior math mode, if the switch operation was
204 // successful in the first place.
~ScopedCublasMathMode()205 ~ScopedCublasMathMode() {
206 if (ok_) {
207 cublasStatus_t ret = cublasSetMathMode(handle_, old_mode_);
208 if (ret != CUBLAS_STATUS_SUCCESS) {
209 LOG(ERROR) << "failed to set former cublas math mode: "
210 << ToString(ret);
211 }
212 }
213 }
214
215 private:
216 cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
217 cublasMath_t old_mode_; // Prior cuBLAS math mode, to be restored.
218 bool ok_; // Whether the change was successful.
219 };
220 #endif // CUDA_VERSION >= 9000
221
Init()222 bool CUDABlas::Init() {
223 gpu::ScopedActivateExecutorContext sac{parent_};
224 cublasStatus_t ret = cublasCreate(&blas_);
225 if (ret != CUBLAS_STATUS_SUCCESS) {
226 LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
227 return false;
228 }
229
230 #if CUDA_VERSION >= 11000
231 ret = cublasLtCreate(&blasLt_);
232 if (ret != CUBLAS_STATUS_SUCCESS) {
233 LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
234 return false;
235 }
236 #endif // CUDA_VERSION >= 11000
237
238 return true;
239 }
240
CUDABlas(gpu::GpuExecutor * parent)241 CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
242 : parent_(CHECK_NOTNULL(parent)),
243 blas_(nullptr)
244 #if CUDA_VERSION >= 11000
245 ,
246 blasLt_(nullptr)
247 #endif
248 {
249 }
250
~CUDABlas()251 CUDABlas::~CUDABlas() {
252 if (blas_ != nullptr) {
253 gpu::ScopedActivateExecutorContext sac{parent_};
254 cublasDestroy(blas_);
255 }
256 #if CUDA_VERSION >= 11000
257 if (blasLt_ != nullptr) {
258 gpu::ScopedActivateExecutorContext sac{parent_};
259 cublasLtDestroy(blasLt_);
260 }
261 #endif
262 }
263
SetStream(Stream * stream)264 bool CUDABlas::SetStream(Stream *stream) {
265 CHECK(stream != nullptr);
266 CHECK(AsGpuStreamValue(stream) != nullptr);
267 CHECK(blas_ != nullptr);
268 gpu::ScopedActivateExecutorContext sac{parent_};
269 cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
270 if (ret != CUBLAS_STATUS_SUCCESS) {
271 LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
272 return false;
273 }
274
275 return true;
276 }
277
CUDAStream(Stream * stream)278 cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
279 CHECK(stream != nullptr);
280 CHECK(AsGpuStreamValue(stream) != nullptr);
281 gpu::ScopedActivateExecutorContext sac{parent_};
282 return AsGpuStreamValue(stream);
283 }
284
285 namespace {
286
287 // Helper functions transforming blas arguments into cuBLAS arguments.
288
CUDABlasTranspose(blas::Transpose trans)289 cublasOperation_t CUDABlasTranspose(blas::Transpose trans) {
290 switch (trans) {
291 case blas::Transpose::kNoTranspose:
292 return CUBLAS_OP_N;
293 case blas::Transpose::kTranspose:
294 return CUBLAS_OP_T;
295 case blas::Transpose::kConjugateTranspose:
296 return CUBLAS_OP_C;
297 default:
298 LOG(FATAL) << "Invalid value of blas::Transpose.";
299 }
300 }
301
CUDABlasUpperLower(blas::UpperLower uplo)302 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
303 switch (uplo) {
304 case blas::UpperLower::kUpper:
305 return CUBLAS_FILL_MODE_UPPER;
306 case blas::UpperLower::kLower:
307 return CUBLAS_FILL_MODE_LOWER;
308 default:
309 LOG(FATAL) << "Invalid value of blas::UpperLower.";
310 }
311 }
312
CUDABlasDiagonal(blas::Diagonal diag)313 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
314 switch (diag) {
315 case blas::Diagonal::kUnit:
316 return CUBLAS_DIAG_UNIT;
317 case blas::Diagonal::kNonUnit:
318 return CUBLAS_DIAG_NON_UNIT;
319 default:
320 LOG(FATAL) << "Invalid value of blas::Diagonal.";
321 }
322 }
323
CUDABlasSide(blas::Side side)324 cublasSideMode_t CUDABlasSide(blas::Side side) {
325 switch (side) {
326 case blas::Side::kLeft:
327 return CUBLAS_SIDE_LEFT;
328 case blas::Side::kRight:
329 return CUBLAS_SIDE_RIGHT;
330 default:
331 LOG(FATAL) << "Invalid value of blas::Side.";
332 }
333 }
334
335 // CUDADataType<T>::type translates from a C++ type (e.g. float) to a
336 // cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a
337 // blas::ComputationType to a cudaDataType_t.
338 //
339 // These are used to build the argument type and computation type args to
340 // cublasGemmEx.
341 template <typename T>
342 struct CUDADataType;
343
344 template <>
345 struct CUDADataType<Eigen::half> {
346 static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
347 };
348
349 template <>
350 struct CUDADataType<std::complex<Eigen::half>> {
351 static constexpr cudaDataType_t type = CUDA_C_16F;
352 };
353
354 template <>
355 struct CUDADataType<float> {
356 static constexpr cudaDataType_t type = CUDA_R_32F;
357 };
358
359 template <>
360 struct CUDADataType<std::complex<float>> {
361 static constexpr cudaDataType_t type = CUDA_C_32F;
362 };
363
364 template <>
365 struct CUDADataType<double> {
366 static constexpr cudaDataType_t type = CUDA_R_64F;
367 };
368
369 template <>
370 struct CUDADataType<std::complex<double>> {
371 static constexpr cudaDataType_t type = CUDA_C_64F;
372 };
373
374 template <>
375 struct CUDADataType<int> {
376 static constexpr cudaDataType_t type = CUDA_R_32I;
377 };
378
379 template <>
380 struct CUDADataType<int8> {
381 static constexpr cudaDataType_t type = CUDA_R_8I;
382 };
383
384 template <>
385 struct CUDADataType<std::complex<int8>> {
386 static constexpr cudaDataType_t type = CUDA_C_8I;
387 };
388
389 template <>
390 struct CUDADataType<uint8> {
391 static constexpr cudaDataType_t type = CUDA_R_8U;
392 };
393
394 template <>
395 struct CUDADataType<std::complex<uint8>> {
396 static constexpr cudaDataType_t type = CUDA_C_8U;
397 };
398
CUDAComputationType(blas::ComputationType ty)399 cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
400 switch (ty) {
401 case blas::ComputationType::kF16:
402 return CUDA_R_16F;
403 case blas::ComputationType::kF32:
404 return CUDA_R_32F;
405 case blas::ComputationType::kF64:
406 return CUDA_R_64F;
407 case blas::ComputationType::kI32:
408 return CUDA_R_32I;
409 case blas::ComputationType::kComplexF32:
410 return CUDA_C_32F;
411 case blas::ComputationType::kComplexF64:
412 return CUDA_C_64F;
413 case blas::ComputationType::kTF32AsF32: // fall-through
414 case blas::ComputationType::kBF16AsF32:
415 // These cases are currently only supported in the blasLt routines, which
416 // use CUBLASComputationType() instead.
417 LOG(FATAL) << "Invalid value of blas::ComputationType.";
418 }
419 }
420
421 #if CUDA_VERSION >= 11000
CUBLASComputationType(blas::ComputationType ty)422 cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
423 switch (ty) {
424 case blas::ComputationType::kF16:
425 return CUBLAS_COMPUTE_16F;
426 case blas::ComputationType::kF32: // fall-through
427 case blas::ComputationType::kComplexF32:
428 return CUBLAS_COMPUTE_32F;
429 case blas::ComputationType::kF64: // fall-through
430 case blas::ComputationType::kComplexF64:
431 return CUBLAS_COMPUTE_64F;
432 case blas::ComputationType::kI32:
433 return CUBLAS_COMPUTE_32I;
434 case blas::ComputationType::kTF32AsF32:
435 return CUBLAS_COMPUTE_32F_FAST_TF32;
436 case blas::ComputationType::kBF16AsF32:
437 return CUBLAS_COMPUTE_32F_FAST_16BF;
438 }
439 }
440 #endif // CUDA_VERSION >= 11000
441
GetScaleType(blas::DataType data_type,blas::ComputationType compute_type)442 blas::DataType GetScaleType(blas::DataType data_type,
443 blas::ComputationType compute_type) {
444 bool is_complex = data_type == blas::DataType::kComplexFloat ||
445 data_type == blas::DataType::kComplexDouble;
446 switch (compute_type) {
447 case blas::ComputationType::kF16:
448 return blas::DataType::kHalf;
449 case blas::ComputationType::kF32: // fall-through
450 case blas::ComputationType::kComplexF32: // fall-through
451 case blas::ComputationType::kTF32AsF32: // fall-through
452 case blas::ComputationType::kBF16AsF32:
453 return is_complex ? blas::DataType::kComplexFloat
454 : blas::DataType::kFloat;
455 case blas::ComputationType::kF64: // fall-through
456 case blas::ComputationType::kComplexF64:
457 return is_complex ? blas::DataType::kComplexDouble
458 : blas::DataType::kDouble;
459 case blas::ComputationType::kI32:
460 return blas::DataType::kInt32;
461 }
462 }
463
464 #if CUDA_VERSION >= 11000
CUBLASPointerMode(blas::PointerMode pointer_mode)465 cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
466 switch (pointer_mode) {
467 case blas::PointerMode::kHost:
468 return CUBLASLT_POINTER_MODE_HOST;
469 case blas::PointerMode::kDevice:
470 return CUBLASLT_POINTER_MODE_DEVICE;
471 }
472 }
CUBLASEpilogue(blas::Epilogue epilogue)473 cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
474 switch (epilogue) {
475 case blas::Epilogue::kDefault:
476 return CUBLASLT_EPILOGUE_DEFAULT;
477 case blas::Epilogue::kReLU:
478 return CUBLASLT_EPILOGUE_RELU;
479 case blas::Epilogue::kBias:
480 return CUBLASLT_EPILOGUE_BIAS;
481 case blas::Epilogue::kBiasThenReLU:
482 return CUBLASLT_EPILOGUE_RELU_BIAS;
483 }
484 }
485 #endif // CUDA_VERSION >= 11000
486
GetCUDADataType(blas::DataType ty)487 cudaDataType_t GetCUDADataType(blas::DataType ty) {
488 switch (ty) {
489 case blas::DataType::kHalf:
490 return CUDA_R_16F;
491 #if CUDA_VERSION >= 11000
492 case blas::DataType::kBF16:
493 return CUDA_R_16BF;
494 #endif
495 case blas::DataType::kFloat:
496 return CUDA_R_32F;
497 case blas::DataType::kDouble:
498 return CUDA_R_64F;
499 case blas::DataType::kInt8:
500 return CUDA_R_8I;
501 case blas::DataType::kInt32:
502 return CUDA_R_32I;
503 case blas::DataType::kComplexFloat:
504 return CUDA_C_32F;
505 case blas::DataType::kComplexDouble:
506 return CUDA_C_64F;
507 default:
508 LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType";
509 }
510 }
511
GetDataTypeSizeBytes(blas::DataType ty)512 int GetDataTypeSizeBytes(blas::DataType ty) {
513 switch (ty) {
514 case blas::DataType::kHalf:
515 return 2;
516 case blas::DataType::kFloat:
517 return 4;
518 case blas::DataType::kDouble:
519 return 8;
520 case blas::DataType::kInt8:
521 return 1;
522 case blas::DataType::kInt32:
523 return 4;
524 case blas::DataType::kComplexFloat:
525 return 8;
526 case blas::DataType::kComplexDouble:
527 return 16;
528 default:
529 LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes";
530 }
531 }
532
533 } // namespace
534
535 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT cublas_func,Stream * stream,bool pointer_mode_host,cublasMath_t math_type,Args...args)536 port::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
537 bool pointer_mode_host,
538 cublasMath_t math_type,
539 Args... args) {
540 absl::MutexLock lock(&mu_);
541
542 CHECK(blas_ != nullptr);
543 if (!SetStream(stream)) {
544 return port::InternalError("Failed setting stream");
545 }
546
547 #if CUDA_VERSION >= 9000
548 ScopedCublasMathMode math_mode{blas_};
549 #if CUBLAS_VER_MAJOR >= 11
550 if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
551 tensorflow::tensor_float_32_execution_enabled()) {
552 #else
553 if (math_type == CUBLAS_TENSOR_OP_MATH) {
554 #endif
555 if (!math_mode.Init(math_type)) {
556 return port::InternalError("Failed initializing math mode");
557 }
558 }
559 #endif
560
561 gpu::ScopedActivateExecutorContext sac{parent_};
562 ScopedCublasPointerMode pointer_mode{blas_};
563 if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
564 : CUBLAS_POINTER_MODE_DEVICE)) {
565 return port::InternalError("Failed setting error mode");
566 }
567 cublasStatus_t ret = cublas_func(blas_, args...);
568 if (ret == CUBLAS_STATUS_SUCCESS) {
569 return port::Status::OK();
570 }
571 return port::InternalError(ToString(ret));
572 }
573
574 // cublas_func may be overloaded, so we need to figure out which one we really
575 // need to call based on the args. One way to do it is to wrap it in lambda.
576 #define AS_LAMBDA(func) \
577 [](auto &&... args) -> decltype( \
578 func(std::forward<decltype(args)>(args)...)) { \
579 return func(std::forward<decltype(args)>(args)...); \
580 }
581
582 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
583 const DeviceMemory<float> &x, int incx,
584 DeviceMemory<float> *result) {
585 return DoBlasInternal(cublasSasum, stream, false /* = pointer_mode_host */,
586 elem_count, GpuMemory(x), incx,
587 GpuMemoryMutable(result));
588 }
589
590 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
591 const DeviceMemory<double> &x, int incx,
592 DeviceMemory<double> *result) {
593 return DoBlasInternal(cublasDasum, stream, false /* = pointer_mode_host */,
594 elem_count, GpuMemory(x), incx,
595 GpuMemoryMutable(result));
596 }
597
598 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
599 const DeviceMemory<std::complex<float>> &x, int incx,
600 DeviceMemory<float> *result) {
601 return DoBlasInternal(cublasScasum, stream, false /* = pointer_mode_host */,
602 elem_count, GpuComplex(GpuMemory(x)), incx,
603 GpuMemoryMutable(result));
604 }
605
606 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
607 const DeviceMemory<std::complex<double>> &x, int incx,
608 DeviceMemory<double> *result) {
609 return DoBlasInternal(cublasDzasum, stream, false /* = pointer_mode_host */,
610 elem_count, GpuComplex(GpuMemory(x)), incx,
611 GpuMemoryMutable(result));
612 }
613
614 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
615 const DeviceMemory<float> &x, int incx,
616 DeviceMemory<float> *y, int incy) {
617 return DoBlasInternal(cublasSaxpy, stream, true /* = pointer_mode_host */,
618 elem_count, &alpha, GpuMemory(x), incx,
619 GpuMemoryMutable(y), incy);
620 }
621
622 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
623 const DeviceMemory<double> &x, int incx,
624 DeviceMemory<double> *y, int incy) {
625 return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */,
626 elem_count, &alpha, GpuMemory(x), incx,
627 GpuMemoryMutable(y), incy);
628 }
629
630 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
631 std::complex<float> alpha,
632 const DeviceMemory<std::complex<float>> &x, int incx,
633 DeviceMemory<std::complex<float>> *y, int incy) {
634 auto cb_alpha = GpuComplexValue(alpha);
635 return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
636 elem_count, GpuComplex(&cb_alpha),
637 GpuComplex(GpuMemory(x)), incx,
638 GpuComplex(GpuMemoryMutable(y)), incy);
639 }
640
641 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
642 std::complex<double> alpha,
643 const DeviceMemory<std::complex<double>> &x, int incx,
644 DeviceMemory<std::complex<double>> *y, int incy) {
645 auto cb_alpha = GpuComplexValue(alpha);
646 return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
647 elem_count, GpuComplex(&cb_alpha),
648 GpuComplex(GpuMemory(x)), incx,
649 GpuComplex(GpuMemoryMutable(y)), incy);
650 }
651
652 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
653 const DeviceMemory<float> &x, int incx,
654 DeviceMemory<float> *y, int incy) {
655 return DoBlasInternal(cublasScopy, stream, true /* = pointer_mode_host */,
656 elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
657 incy);
658 }
659
660 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
661 const DeviceMemory<double> &x, int incx,
662 DeviceMemory<double> *y, int incy) {
663 return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */,
664 elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
665 incy);
666 }
667
668 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
669 const DeviceMemory<std::complex<float>> &x, int incx,
670 DeviceMemory<std::complex<float>> *y, int incy) {
671 return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */,
672 elem_count, GpuComplex(GpuMemory(x)), incx,
673 GpuComplex(GpuMemoryMutable(y)), incy);
674 }
675
676 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
677 const DeviceMemory<std::complex<double>> &x, int incx,
678 DeviceMemory<std::complex<double>> *y, int incy) {
679 return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */,
680 elem_count, GpuComplex(GpuMemory(x)), incx,
681 GpuComplex(GpuMemoryMutable(y)), incy);
682 }
683
684 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
685 const DeviceMemory<float> &x, int incx,
686 const DeviceMemory<float> &y, int incy,
687 DeviceMemory<float> *result) {
688 return DoBlasInternal(cublasSdot, stream, false /* = pointer_mode_host */,
689 elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
690 GpuMemoryMutable(result));
691 }
692
693 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
694 const DeviceMemory<double> &x, int incx,
695 const DeviceMemory<double> &y, int incy,
696 DeviceMemory<double> *result) {
697 return DoBlasInternal(cublasDdot, stream, false /* = pointer_mode_host */,
698 elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
699 GpuMemoryMutable(result));
700 }
701
702 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
703 const DeviceMemory<std::complex<float>> &x, int incx,
704 const DeviceMemory<std::complex<float>> &y, int incy,
705 DeviceMemory<std::complex<float>> *result) {
706 return DoBlasInternal(cublasCdotc, stream, false /* = pointer_mode_host */,
707 elem_count, GpuComplex(GpuMemory(x)), incx,
708 GpuComplex(GpuMemory(y)), incy,
709 GpuComplex(GpuMemoryMutable(result)));
710 }
711
712 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
713 const DeviceMemory<std::complex<double>> &x, int incx,
714 const DeviceMemory<std::complex<double>> &y, int incy,
715 DeviceMemory<std::complex<double>> *result) {
716 return DoBlasInternal(cublasZdotc, stream, false /* = pointer_mode_host */,
717 elem_count, GpuComplex(GpuMemory(x)), incx,
718 GpuComplex(GpuMemory(y)), incy,
719 GpuComplex(GpuMemoryMutable(result)));
720 }
721
722 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
723 const DeviceMemory<std::complex<float>> &x, int incx,
724 const DeviceMemory<std::complex<float>> &y, int incy,
725 DeviceMemory<std::complex<float>> *result) {
726 return DoBlasInternal(cublasCdotu, stream, false /* = pointer_mode_host */,
727 elem_count, GpuComplex(GpuMemory(x)), incx,
728 GpuComplex(GpuMemory(y)), incy,
729 GpuComplex(GpuMemoryMutable(result)));
730 }
731
732 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
733 const DeviceMemory<std::complex<double>> &x, int incx,
734 const DeviceMemory<std::complex<double>> &y, int incy,
735 DeviceMemory<std::complex<double>> *result) {
736 return DoBlasInternal(cublasZdotu, stream, false /* = pointer_mode_host */,
737 elem_count, GpuComplex(GpuMemory(x)), incx,
738 GpuComplex(GpuMemory(y)), incy,
739 GpuComplex(GpuMemoryMutable(result)));
740 }
741
742 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
743 const DeviceMemory<float> &x, int incx,
744 DeviceMemory<float> *result) {
745 return DoBlasInternal(cublasSnrm2, stream, false /* = pointer_mode_host */,
746 elem_count, GpuMemory(x), incx,
747 GpuMemoryMutable(result));
748 }
749
750 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
751 const DeviceMemory<double> &x, int incx,
752 DeviceMemory<double> *result) {
753 return DoBlasInternal(cublasDnrm2, stream, false /* = pointer_mode_host */,
754 elem_count, GpuMemory(x), incx,
755 GpuMemoryMutable(result));
756 }
757
758 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
759 const DeviceMemory<std::complex<float>> &x, int incx,
760 DeviceMemory<float> *result) {
761 return DoBlasInternal(cublasScnrm2, stream, false /* = pointer_mode_host */,
762 elem_count, GpuComplex(GpuMemory(x)), incx,
763 GpuMemoryMutable(result));
764 }
765
766 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
767 const DeviceMemory<std::complex<double>> &x, int incx,
768 DeviceMemory<double> *result) {
769 return DoBlasInternal(cublasDznrm2, stream, false /* = pointer_mode_host */,
770 elem_count, GpuComplex(GpuMemory(x)), incx,
771 GpuMemoryMutable(result));
772 }
773
774 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
775 DeviceMemory<float> *x, int incx,
776 DeviceMemory<float> *y, int incy, float c, float s) {
777 return DoBlasInternal(cublasSrot, stream, true /* = pointer_mode_host */,
778 elem_count, GpuMemoryMutable(x), incx,
779 GpuMemoryMutable(y), incy, &c, &s);
780 }
781
782 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
783 DeviceMemory<double> *x, int incx,
784 DeviceMemory<double> *y, int incy, double c,
785 double s) {
786 return DoBlasInternal(cublasDrot, stream, true /* = pointer_mode_host */,
787 elem_count, GpuMemoryMutable(x), incx,
788 GpuMemoryMutable(y), incy, &c, &s);
789 }
790
791 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
792 DeviceMemory<std::complex<float>> *x, int incx,
793 DeviceMemory<std::complex<float>> *y, int incy,
794 float c, float s) {
795 return DoBlasInternal(cublasCsrot, stream, true /* = pointer_mode_host */,
796 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
797 GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
798 }
799
800 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
801 DeviceMemory<std::complex<double>> *x, int incx,
802 DeviceMemory<std::complex<double>> *y, int incy,
803 double c, double s) {
804 return DoBlasInternal(cublasZdrot, stream, true /* = pointer_mode_host */,
805 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
806 GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
807 }
808
809 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
810 DeviceMemory<float> *b, DeviceMemory<float> *c,
811 DeviceMemory<float> *s) {
812 return DoBlasInternal(cublasSrotg, stream, false /* = pointer_mode_host */,
813 GpuMemoryMutable(a), GpuMemoryMutable(b),
814 GpuMemoryMutable(c), GpuMemoryMutable(s));
815 }
816
817 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
818 DeviceMemory<double> *b, DeviceMemory<double> *c,
819 DeviceMemory<double> *s) {
820 return DoBlasInternal(cublasDrotg, stream, false /* = pointer_mode_host */,
821 GpuComplex(GpuMemoryMutable(a)), GpuMemoryMutable(b),
822 GpuMemoryMutable(c), GpuMemoryMutable(s));
823 }
824
825 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
826 DeviceMemory<std::complex<float>> *b,
827 DeviceMemory<float> *c,
828 DeviceMemory<std::complex<float>> *s) {
829 return DoBlasInternal(
830 cublasCrotg, stream, false /* = pointer_mode_host */,
831 GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
832 GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
833 }
834
835 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
836 DeviceMemory<std::complex<double>> *b,
837 DeviceMemory<double> *c,
838 DeviceMemory<std::complex<double>> *s) {
839 return DoBlasInternal(
840 cublasZrotg, stream, false /* = pointer_mode_host */,
841 GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
842 GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
843 }
844
845 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
846 DeviceMemory<float> *x, int incx,
847 DeviceMemory<float> *y, int incy,
848 const DeviceMemory<float> ¶m) {
849 return DoBlasInternal(cublasSrotm, stream, false /* = pointer_mode_host */,
850 elem_count, GpuMemoryMutable(x), incx,
851 GpuMemoryMutable(y), incy, GpuMemory(param));
852 }
853
854 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
855 DeviceMemory<double> *x, int incx,
856 DeviceMemory<double> *y, int incy,
857 const DeviceMemory<double> ¶m) {
858 return DoBlasInternal(cublasDrotm, stream, false /* = pointer_mode_host */,
859 elem_count, GpuMemoryMutable(x), incx,
860 GpuMemoryMutable(y), incy, GpuMemory(param));
861 }
862
863 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
864 DeviceMemory<float> *d2, DeviceMemory<float> *x1,
865 const DeviceMemory<float> &y1,
866 DeviceMemory<float> *param) {
867 return DoBlasInternal(cublasSrotmg, stream, false /* = pointer_mode_host */,
868 GpuMemoryMutable(d1), GpuMemoryMutable(d2),
869 GpuMemoryMutable(x1), GpuMemory(y1),
870 GpuMemoryMutable(param));
871 }
872
873 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
874 DeviceMemory<double> *d2, DeviceMemory<double> *x1,
875 const DeviceMemory<double> &y1,
876 DeviceMemory<double> *param) {
877 return DoBlasInternal(cublasDrotmg, stream, false /* = pointer_mode_host */,
878 GpuMemoryMutable(d1), GpuMemoryMutable(d2),
879 GpuMemoryMutable(x1), GpuMemory(y1),
880 GpuMemoryMutable(param));
881 }
882
883 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
884 DeviceMemory<float> *x, int incx) {
885 return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */,
886 elem_count, &alpha, GpuMemoryMutable(x), incx);
887 }
888
889 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
890 DeviceMemory<double> *x, int incx) {
891 return DoBlasInternal(cublasDscal, stream, true /* = pointer_mode_host */,
892 elem_count, &alpha, GpuMemoryMutable(x), incx);
893 }
894
895 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
896 DeviceMemory<std::complex<float>> *x, int incx) {
897 return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
898 elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
899 incx);
900 }
901
902 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
903 DeviceMemory<std::complex<double>> *x, int incx) {
904 return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
905 elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
906 incx);
907 }
908
909 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
910 std::complex<float> alpha,
911 DeviceMemory<std::complex<float>> *x, int incx) {
912 auto cb_alpha = GpuComplexValue(alpha);
913 return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
914 elem_count, GpuComplex(&cb_alpha),
915 GpuComplex(GpuMemoryMutable(x)), incx);
916 }
917
918 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
919 std::complex<double> alpha,
920 DeviceMemory<std::complex<double>> *x, int incx) {
921 auto cb_alpha = GpuComplexValue(alpha);
922 return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
923 elem_count, GpuComplex(&cb_alpha),
924 GpuComplex(GpuMemoryMutable(x)), incx);
925 }
926
927 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
928 DeviceMemory<float> *x, int incx,
929 DeviceMemory<float> *y, int incy) {
930 return DoBlasInternal(cublasSswap, stream, true /* = pointer_mode_host */,
931 elem_count, GpuMemoryMutable(x), incx,
932 GpuMemoryMutable(y), incy);
933 }
934
935 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
936 DeviceMemory<double> *x, int incx,
937 DeviceMemory<double> *y, int incy) {
938 return DoBlasInternal(cublasDswap, stream, true /* = pointer_mode_host */,
939 elem_count, GpuMemoryMutable(x), incx,
940 GpuMemoryMutable(y), incy);
941 }
942
943 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
944 DeviceMemory<std::complex<float>> *x, int incx,
945 DeviceMemory<std::complex<float>> *y, int incy) {
946 return DoBlasInternal(cublasCswap, stream, true /* = pointer_mode_host */,
947 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
948 GpuComplex(GpuMemoryMutable(y)), incy);
949 }
950
951 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
952 DeviceMemory<std::complex<double>> *x, int incx,
953 DeviceMemory<std::complex<double>> *y, int incy) {
954 return DoBlasInternal(cublasZswap, stream, true /* = pointer_mode_host */,
955 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
956 GpuComplex(GpuMemoryMutable(y)), incy);
957 }
958
959 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
960 const DeviceMemory<float> &x, int incx,
961 DeviceMemory<int> *result) {
962 return DoBlasInternal(cublasIsamax, stream, false /* = pointer_mode_host */,
963 elem_count, GpuMemory(x), incx,
964 GpuMemoryMutable(result));
965 }
966
967 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
968 const DeviceMemory<double> &x, int incx,
969 DeviceMemory<int> *result) {
970 return DoBlasInternal(cublasIdamax, stream, false /* = pointer_mode_host */,
971 elem_count, GpuMemory(x), incx,
972 GpuMemoryMutable(result));
973 }
974
975 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
976 const DeviceMemory<std::complex<float>> &x, int incx,
977 DeviceMemory<int> *result) {
978 return DoBlasInternal(cublasIcamax, stream, false /* = pointer_mode_host */,
979 elem_count, GpuComplex(GpuMemory(x)), incx,
980 GpuMemoryMutable(result));
981 }
982
983 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
984 const DeviceMemory<std::complex<double>> &x,
985 int incx, DeviceMemory<int> *result) {
986 return DoBlasInternal(cublasIzamax, stream, false /* = pointer_mode_host */,
987 elem_count, GpuComplex(GpuMemory(x)), incx,
988 GpuMemoryMutable(result));
989 }
990
991 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
992 const DeviceMemory<float> &x, int incx,
993 DeviceMemory<int> *result) {
994 return DoBlasInternal(cublasIsamin, stream, false /* = pointer_mode_host */,
995 elem_count, GpuComplex(GpuMemory(x)), incx,
996 GpuMemoryMutable(result));
997 }
998
999 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1000 const DeviceMemory<double> &x, int incx,
1001 DeviceMemory<int> *result) {
1002 return DoBlasInternal(cublasIdamin, stream, false /* = pointer_mode_host */,
1003 elem_count, GpuComplex(GpuMemory(x)), incx,
1004 GpuMemoryMutable(result));
1005 }
1006
1007 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1008 const DeviceMemory<std::complex<float>> &x, int incx,
1009 DeviceMemory<int> *result) {
1010 return DoBlasInternal(cublasIcamin, stream, false /* = pointer_mode_host */,
1011 elem_count, GpuComplex(GpuMemory(x)), incx,
1012 GpuMemoryMutable(result));
1013 }
1014
1015 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1016 const DeviceMemory<std::complex<double>> &x,
1017 int incx, DeviceMemory<int> *result) {
1018 return DoBlasInternal(cublasIzamin, stream, false /* = pointer_mode_host */,
1019 elem_count, GpuComplex(GpuMemory(x)), incx,
1020 GpuMemoryMutable(result));
1021 }
1022
1023 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1024 uint64 n, uint64 kl, uint64 ku, float alpha,
1025 const DeviceMemory<float> &a, int lda,
1026 const DeviceMemory<float> &x, int incx, float beta,
1027 DeviceMemory<float> *y, int incy) {
1028 return DoBlasInternal(cublasSgbmv, stream, true /* = pointer_mode_host */,
1029 CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1030 GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1031 GpuMemoryMutable(y), incy);
1032 }
1033
1034 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1035 uint64 n, uint64 kl, uint64 ku, double alpha,
1036 const DeviceMemory<double> &a, int lda,
1037 const DeviceMemory<double> &x, int incx, double beta,
1038 DeviceMemory<double> *y, int incy) {
1039 return DoBlasInternal(cublasDgbmv, stream, true /* = pointer_mode_host */,
1040 CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1041 GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1042 GpuMemoryMutable(y), incy);
1043 }
1044
1045 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1046 uint64 n, uint64 kl, uint64 ku,
1047 std::complex<float> alpha,
1048 const DeviceMemory<std::complex<float>> &a, int lda,
1049 const DeviceMemory<std::complex<float>> &x, int incx,
1050 std::complex<float> beta,
1051 DeviceMemory<std::complex<float>> *y, int incy) {
1052 auto cb_alpha = GpuComplexValue(alpha);
1053 auto cb_beta = GpuComplexValue(beta);
1054 return DoBlasInternal(cublasCgbmv, stream, true /* = pointer_mode_host */,
1055 CUDABlasTranspose(trans), m, n, kl, ku,
1056 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1057 GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1058 GpuComplex(GpuMemoryMutable(y)), incy);
1059 }
1060
1061 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1062 uint64 n, uint64 kl, uint64 ku,
1063 std::complex<double> alpha,
1064 const DeviceMemory<std::complex<double>> &a, int lda,
1065 const DeviceMemory<std::complex<double>> &x, int incx,
1066 std::complex<double> beta,
1067 DeviceMemory<std::complex<double>> *y, int incy) {
1068 auto cb_alpha = GpuComplexValue(alpha);
1069 auto cb_beta = GpuComplexValue(beta);
1070 return DoBlasInternal(cublasZgbmv, stream, true /* = pointer_mode_host */,
1071 CUDABlasTranspose(trans), m, n, kl, ku,
1072 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1073 GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1074 GpuComplex(GpuMemoryMutable(y)), incy);
1075 }
1076
1077 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1078 uint64 n, float alpha, const DeviceMemory<float> &a,
1079 int lda, const DeviceMemory<float> &x, int incx,
1080 float beta, DeviceMemory<float> *y, int incy) {
1081 return DoBlasInternal(cublasSgemv, stream, true /* = pointer_mode_host */,
1082 CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1083 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1084 incy);
1085 }
1086
1087 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1088 uint64 n, double alpha, const DeviceMemory<double> &a,
1089 int lda, const DeviceMemory<double> &x, int incx,
1090 double beta, DeviceMemory<double> *y, int incy) {
1091 return DoBlasInternal(cublasDgemv, stream, true /* = pointer_mode_host */,
1092 CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1093 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1094 incy);
1095 }
1096
1097 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1098 uint64 n, std::complex<float> alpha,
1099 const DeviceMemory<std::complex<float>> &a, int lda,
1100 const DeviceMemory<std::complex<float>> &x, int incx,
1101 std::complex<float> beta,
1102 DeviceMemory<std::complex<float>> *y, int incy) {
1103 auto cb_alpha = GpuComplexValue(alpha);
1104 auto cb_beta = GpuComplexValue(beta);
1105 return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
1106 CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1107 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1108 incx, GpuComplex(&cb_beta),
1109 GpuComplex(GpuMemoryMutable(y)), incy);
1110 }
1111
1112 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1113 uint64 n, std::complex<double> alpha,
1114 const DeviceMemory<std::complex<double>> &a, int lda,
1115 const DeviceMemory<std::complex<double>> &x, int incx,
1116 std::complex<double> beta,
1117 DeviceMemory<std::complex<double>> *y, int incy) {
1118 auto cb_alpha = GpuComplexValue(alpha);
1119 auto cb_beta = GpuComplexValue(beta);
1120 return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
1121 CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1122 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1123 incx, GpuComplex(&cb_beta),
1124 GpuComplex(GpuMemoryMutable(y)), incy);
1125 }
1126
1127 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
1128 const DeviceMemory<float> &x, int incx,
1129 const DeviceMemory<float> &y, int incy,
1130 DeviceMemory<float> *a, int lda) {
1131 return DoBlasInternal(cublasSger, stream, true /* = pointer_mode_host */, m,
1132 n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1133 GpuMemoryMutable(a), lda);
1134 }
1135
1136 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
1137 const DeviceMemory<double> &x, int incx,
1138 const DeviceMemory<double> &y, int incy,
1139 DeviceMemory<double> *a, int lda) {
1140 return DoBlasInternal(cublasDger, stream, true /* = pointer_mode_host */, m,
1141 n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1142 GpuMemoryMutable(a), lda);
1143 }
1144
1145 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1146 std::complex<float> alpha,
1147 const DeviceMemory<std::complex<float>> &x, int incx,
1148 const DeviceMemory<std::complex<float>> &y, int incy,
1149 DeviceMemory<std::complex<float>> *a, int lda) {
1150 auto cb_alpha = GpuComplexValue(alpha);
1151 return DoBlasInternal(cublasCgerc, stream, true /* = pointer_mode_host */, m,
1152 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1153 incx, GpuComplex(GpuMemory(y)), incy,
1154 GpuComplex(GpuMemoryMutable(a)), lda);
1155 }
1156
1157 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1158 std::complex<double> alpha,
1159 const DeviceMemory<std::complex<double>> &x, int incx,
1160 const DeviceMemory<std::complex<double>> &y, int incy,
1161 DeviceMemory<std::complex<double>> *a, int lda) {
1162 auto cb_alpha = GpuComplexValue(alpha);
1163 return DoBlasInternal(cublasZgerc, stream, true /* = pointer_mode_host */, m,
1164 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1165 incx, GpuComplex(GpuMemory(y)), incy,
1166 GpuComplex(GpuMemoryMutable(a)), lda);
1167 }
1168
1169 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1170 std::complex<float> alpha,
1171 const DeviceMemory<std::complex<float>> &x, int incx,
1172 const DeviceMemory<std::complex<float>> &y, int incy,
1173 DeviceMemory<std::complex<float>> *a, int lda) {
1174 auto cb_alpha = GpuComplexValue(alpha);
1175 return DoBlasInternal(cublasCgeru, stream, true /* = pointer_mode_host */, m,
1176 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1177 incx, GpuComplex(GpuMemory(y)), incy,
1178 GpuComplex(GpuMemoryMutable(a)), lda);
1179 }
1180
1181 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1182 std::complex<double> alpha,
1183 const DeviceMemory<std::complex<double>> &x, int incx,
1184 const DeviceMemory<std::complex<double>> &y, int incy,
1185 DeviceMemory<std::complex<double>> *a, int lda) {
1186 auto cb_alpha = GpuComplexValue(alpha);
1187 return DoBlasInternal(cublasZgeru, stream, true /* = pointer_mode_host */, m,
1188 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1189 incx, GpuComplex(GpuMemory(y)), incy,
1190 GpuComplex(GpuMemoryMutable(a)), lda);
1191 }
1192
1193 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1194 uint64 k, std::complex<float> alpha,
1195 const DeviceMemory<std::complex<float>> &a, int lda,
1196 const DeviceMemory<std::complex<float>> &x, int incx,
1197 std::complex<float> beta,
1198 DeviceMemory<std::complex<float>> *y, int incy) {
1199 auto cb_alpha = GpuComplexValue(alpha);
1200 auto cb_beta = GpuComplexValue(beta);
1201 return DoBlasInternal(cublasChbmv, stream, true /* = pointer_mode_host */,
1202 CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
1203 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1204 incx, GpuComplex(&cb_beta),
1205 GpuComplex(GpuMemoryMutable(y)), incy);
1206 }
1207
1208 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1209 uint64 k, std::complex<double> alpha,
1210 const DeviceMemory<std::complex<double>> &a, int lda,
1211 const DeviceMemory<std::complex<double>> &x, int incx,
1212 std::complex<double> beta,
1213 DeviceMemory<std::complex<double>> *y, int incy) {
1214 auto cb_alpha = GpuComplexValue(alpha);
1215 auto cb_beta = GpuComplexValue(beta);
1216 return DoBlasInternal(cublasZhbmv, stream, true /* = pointer_mode_host */,
1217 CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
1218 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1219 incx, GpuComplex(&cb_beta),
1220 GpuComplex(GpuMemoryMutable(y)), incy);
1221 }
1222
1223 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1224 std::complex<float> alpha,
1225 const DeviceMemory<std::complex<float>> &a, int lda,
1226 const DeviceMemory<std::complex<float>> &x, int incx,
1227 std::complex<float> beta,
1228 DeviceMemory<std::complex<float>> *y, int incy) {
1229 auto cb_alpha = GpuComplexValue(alpha);
1230 auto cb_beta = GpuComplexValue(beta);
1231 return DoBlasInternal(cublasChemv, stream, true /* = pointer_mode_host */,
1232 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1233 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1234 incx, GpuComplex(&cb_beta),
1235 GpuComplex(GpuMemoryMutable(y)), incy);
1236 }
1237
1238 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1239 std::complex<double> alpha,
1240 const DeviceMemory<std::complex<double>> &a, int lda,
1241 const DeviceMemory<std::complex<double>> &x, int incx,
1242 std::complex<double> beta,
1243 DeviceMemory<std::complex<double>> *y, int incy) {
1244 auto cb_alpha = GpuComplexValue(alpha);
1245 auto cb_beta = GpuComplexValue(beta);
1246 return DoBlasInternal(cublasZhemv, stream, true /* = pointer_mode_host */,
1247 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1248 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1249 incx, GpuComplex(&cb_beta),
1250 GpuComplex(GpuMemoryMutable(y)), incy);
1251 }
1252
1253 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1254 float alpha,
1255 const DeviceMemory<std::complex<float>> &x, int incx,
1256 DeviceMemory<std::complex<float>> *a, int lda) {
1257 return DoBlasInternal(cublasCher, stream, true /* = pointer_mode_host */,
1258 CUDABlasUpperLower(uplo), n, &alpha,
1259 GpuComplex(GpuMemory(x)), incx,
1260 GpuComplex(GpuMemoryMutable(a)), lda);
1261 }
1262
1263 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1264 double alpha,
1265 const DeviceMemory<std::complex<double>> &x, int incx,
1266 DeviceMemory<std::complex<double>> *a, int lda) {
1267 return DoBlasInternal(cublasZher, stream, true /* = pointer_mode_host */,
1268 CUDABlasUpperLower(uplo), n, &alpha,
1269 GpuComplex(GpuMemory(x)), incx,
1270 GpuComplex(GpuMemoryMutable(a)), lda);
1271 }
1272
1273 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1274 std::complex<float> alpha,
1275 const DeviceMemory<std::complex<float>> &x, int incx,
1276 const DeviceMemory<std::complex<float>> &y, int incy,
1277 DeviceMemory<std::complex<float>> *a, int lda) {
1278 auto cb_alpha = GpuComplexValue(alpha);
1279 return DoBlasInternal(cublasCher2, stream, true /* = pointer_mode_host */,
1280 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1281 GpuComplex(GpuMemory(x)), incx,
1282 GpuComplex(GpuMemory(y)), incy,
1283 GpuComplex(GpuMemoryMutable(a)), lda);
1284 }
1285
1286 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1287 std::complex<double> alpha,
1288 const DeviceMemory<std::complex<double>> &x, int incx,
1289 const DeviceMemory<std::complex<double>> &y, int incy,
1290 DeviceMemory<std::complex<double>> *a, int lda) {
1291 auto cb_alpha = GpuComplexValue(alpha);
1292 return DoBlasInternal(cublasZher2, stream, true /* = pointer_mode_host */,
1293 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1294 GpuComplex(GpuMemory(x)), incx,
1295 GpuComplex(GpuMemory(y)), incy,
1296 GpuComplex(GpuMemoryMutable(a)), lda);
1297 }
1298
1299 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1300 std::complex<float> alpha,
1301 const DeviceMemory<std::complex<float>> &ap,
1302 const DeviceMemory<std::complex<float>> &x, int incx,
1303 std::complex<float> beta,
1304 DeviceMemory<std::complex<float>> *y, int incy) {
1305 auto cb_alpha = GpuComplexValue(alpha);
1306 auto cb_beta = GpuComplexValue(beta);
1307 return DoBlasInternal(cublasChpmv, stream, true /* = pointer_mode_host */,
1308 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1309 GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
1310 incx, GpuComplex(&cb_beta),
1311 GpuComplex(GpuMemoryMutable(y)), incy);
1312 }
1313
1314 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1315 std::complex<double> alpha,
1316 const DeviceMemory<std::complex<double>> &ap,
1317 const DeviceMemory<std::complex<double>> &x, int incx,
1318 std::complex<double> beta,
1319 DeviceMemory<std::complex<double>> *y, int incy) {
1320 auto cb_alpha = GpuComplexValue(alpha);
1321 auto cb_beta = GpuComplexValue(beta);
1322 return DoBlasInternal(cublasZhpmv, stream, true /* = pointer_mode_host */,
1323 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1324 GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
1325 incx, GpuComplex(&cb_beta),
1326 GpuComplex(GpuMemoryMutable(y)), incy);
1327 }
1328
1329 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1330 float alpha,
1331 const DeviceMemory<std::complex<float>> &x, int incx,
1332 DeviceMemory<std::complex<float>> *ap) {
1333 return DoBlasInternal(cublasChpr, stream, true /* = pointer_mode_host */,
1334 CUDABlasUpperLower(uplo), n, &alpha,
1335 GpuComplex(GpuMemory(x)), incx,
1336 GpuComplex(GpuMemoryMutable(ap)));
1337 }
1338
1339 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1340 double alpha,
1341 const DeviceMemory<std::complex<double>> &x, int incx,
1342 DeviceMemory<std::complex<double>> *ap) {
1343 return DoBlasInternal(cublasZhpr, stream, true /* = pointer_mode_host */,
1344 CUDABlasUpperLower(uplo), n, &alpha,
1345 GpuComplex(GpuMemory(x)), incx,
1346 GpuComplex(GpuMemoryMutable(ap)));
1347 }
1348
1349 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1350 std::complex<float> alpha,
1351 const DeviceMemory<std::complex<float>> &x, int incx,
1352 const DeviceMemory<std::complex<float>> &y, int incy,
1353 DeviceMemory<std::complex<float>> *ap) {
1354 auto cb_alpha = GpuComplexValue(alpha);
1355 return DoBlasInternal(cublasChpr2, stream, true /* = pointer_mode_host */,
1356 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1357 GpuComplex(GpuMemory(x)), incx,
1358 GpuComplex(GpuMemory(y)), incy,
1359 GpuComplex(GpuMemoryMutable(ap)));
1360 }
1361
1362 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1363 std::complex<double> alpha,
1364 const DeviceMemory<std::complex<double>> &x, int incx,
1365 const DeviceMemory<std::complex<double>> &y, int incy,
1366 DeviceMemory<std::complex<double>> *ap) {
1367 auto cb_alpha = GpuComplexValue(alpha);
1368 return DoBlasInternal(cublasZhpr2, stream, true /* = pointer_mode_host */,
1369 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1370 GpuComplex(GpuMemory(x)), incx,
1371 GpuComplex(GpuMemory(y)), incy,
1372 GpuComplex(GpuMemoryMutable(ap)));
1373 }
1374
1375 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1376 uint64 k, float alpha, const DeviceMemory<float> &a,
1377 int lda, const DeviceMemory<float> &x, int incx,
1378 float beta, DeviceMemory<float> *y, int incy) {
1379 return DoBlasInternal(cublasSsbmv, stream, true /* = pointer_mode_host */,
1380 CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
1381 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1382 incy);
1383 }
1384
1385 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1386 uint64 k, double alpha, const DeviceMemory<double> &a,
1387 int lda, const DeviceMemory<double> &x, int incx,
1388 double beta, DeviceMemory<double> *y, int incy) {
1389 return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */,
1390 CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
1391 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1392 incy);
1393 }
1394
1395 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1396 float alpha, const DeviceMemory<float> &ap,
1397 const DeviceMemory<float> &x, int incx, float beta,
1398 DeviceMemory<float> *y, int incy) {
1399 return DoBlasInternal(cublasSspmv, stream, true /* = pointer_mode_host */,
1400 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1401 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1402 }
1403
1404 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1405 double alpha, const DeviceMemory<double> &ap,
1406 const DeviceMemory<double> &x, int incx, double beta,
1407 DeviceMemory<double> *y, int incy) {
1408 return DoBlasInternal(cublasDspmv, stream, true /* = pointer_mode_host */,
1409 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1410 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1411 }
1412
1413 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1414 float alpha, const DeviceMemory<float> &x, int incx,
1415 DeviceMemory<float> *ap) {
1416 return DoBlasInternal(cublasSspr, stream, true /* = pointer_mode_host */,
1417 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1418 GpuMemoryMutable(ap));
1419 }
1420
1421 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1422 double alpha, const DeviceMemory<double> &x, int incx,
1423 DeviceMemory<double> *ap) {
1424 return DoBlasInternal(cublasDspr, stream, true /* = pointer_mode_host */,
1425 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1426 GpuMemoryMutable(ap));
1427 }
1428
1429 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1430 float alpha, const DeviceMemory<float> &x, int incx,
1431 const DeviceMemory<float> &y, int incy,
1432 DeviceMemory<float> *ap) {
1433 return DoBlasInternal(cublasSspr2, stream, true /* = pointer_mode_host */,
1434 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1435 GpuMemory(y), incy, GpuMemoryMutable(ap));
1436 }
1437
1438 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1439 double alpha, const DeviceMemory<double> &x, int incx,
1440 const DeviceMemory<double> &y, int incy,
1441 DeviceMemory<double> *ap) {
1442 return DoBlasInternal(cublasDspr2, stream, true /* = pointer_mode_host */,
1443 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1444 GpuMemory(y), incy, GpuMemoryMutable(ap));
1445 }
1446
1447 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1448 float alpha, const DeviceMemory<float> &a, int lda,
1449 const DeviceMemory<float> &x, int incx, float beta,
1450 DeviceMemory<float> *y, int incy) {
1451 return DoBlasInternal(cublasSsymv, stream, true /* = pointer_mode_host */,
1452 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1453 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1454 }
1455
1456 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1457 double alpha, const DeviceMemory<double> &a, int lda,
1458 const DeviceMemory<double> &x, int incx, double beta,
1459 DeviceMemory<double> *y, int incy) {
1460 return DoBlasInternal(cublasDsymv, stream, true /* = pointer_mode_host */,
1461 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1462 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1463 }
1464
1465 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1466 float alpha, const DeviceMemory<float> &x, int incx,
1467 DeviceMemory<float> *a, int lda) {
1468 return DoBlasInternal(cublasSsyr, stream, true /* = pointer_mode_host */,
1469 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1470 GpuMemoryMutable(a), lda);
1471 }
1472
1473 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1474 double alpha, const DeviceMemory<double> &x, int incx,
1475 DeviceMemory<double> *a, int lda) {
1476 return DoBlasInternal(cublasDsyr, stream, true /* = pointer_mode_host */,
1477 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1478 GpuMemoryMutable(a), lda);
1479 }
1480
1481 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1482 float alpha, const DeviceMemory<float> &x, int incx,
1483 const DeviceMemory<float> &y, int incy,
1484 DeviceMemory<float> *a, int lda) {
1485 return DoBlasInternal(cublasSsyr2, stream, true /* = pointer_mode_host */,
1486 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1487 GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1488 }
1489
1490 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1491 double alpha, const DeviceMemory<double> &x, int incx,
1492 const DeviceMemory<double> &y, int incy,
1493 DeviceMemory<double> *a, int lda) {
1494 return DoBlasInternal(cublasDsyr2, stream, true /* = pointer_mode_host */,
1495 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1496 GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1497 }
1498
1499 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1500 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1501 uint64 k, const DeviceMemory<float> &a, int lda,
1502 DeviceMemory<float> *x, int incx) {
1503 return DoBlasInternal(cublasStbmv, stream, true /* = pointer_mode_host */,
1504 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1505 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1506 GpuMemoryMutable(x), incx);
1507 }
1508
1509 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1510 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1511 uint64 k, const DeviceMemory<double> &a, int lda,
1512 DeviceMemory<double> *x, int incx) {
1513 return DoBlasInternal(cublasDtbmv, stream, true /* = pointer_mode_host */,
1514 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1515 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1516 GpuMemoryMutable(x), incx);
1517 }
1518
1519 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1520 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1521 uint64 k, const DeviceMemory<std::complex<float>> &a,
1522 int lda, DeviceMemory<std::complex<float>> *x,
1523 int incx) {
1524 return DoBlasInternal(cublasCtbmv, stream, true /* = pointer_mode_host */,
1525 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1526 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1527 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1528 }
1529
1530 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1531 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1532 uint64 k, const DeviceMemory<std::complex<double>> &a,
1533 int lda, DeviceMemory<std::complex<double>> *x,
1534 int incx) {
1535 return DoBlasInternal(cublasZtbmv, stream, true /* = pointer_mode_host */,
1536 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1537 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1538 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1539 }
1540
1541 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1542 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1543 uint64 k, const DeviceMemory<float> &a, int lda,
1544 DeviceMemory<float> *x, int incx) {
1545 return DoBlasInternal(cublasStbsv, stream, true /* = pointer_mode_host */,
1546 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1547 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1548 GpuMemoryMutable(x), incx);
1549 }
1550
1551 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1552 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1553 uint64 k, const DeviceMemory<double> &a, int lda,
1554 DeviceMemory<double> *x, int incx) {
1555 return DoBlasInternal(cublasDtbsv, stream, true /* = pointer_mode_host */,
1556 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1557 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1558 GpuMemoryMutable(x), incx);
1559 }
1560
1561 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1562 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1563 uint64 k, const DeviceMemory<std::complex<float>> &a,
1564 int lda, DeviceMemory<std::complex<float>> *x,
1565 int incx) {
1566 return DoBlasInternal(cublasCtbsv, stream, true /* = pointer_mode_host */,
1567 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1568 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1569 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1570 }
1571
1572 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1573 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1574 uint64 k, const DeviceMemory<std::complex<double>> &a,
1575 int lda, DeviceMemory<std::complex<double>> *x,
1576 int incx) {
1577 return DoBlasInternal(cublasZtbsv, stream, true /* = pointer_mode_host */,
1578 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1579 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1580 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1581 }
1582
1583 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1584 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1585 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1586 int incx) {
1587 return DoBlasInternal(cublasStpmv, stream, true /* = pointer_mode_host */,
1588 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1589 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1590 GpuMemoryMutable(x), incx);
1591 }
1592
1593 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1594 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1595 const DeviceMemory<double> &ap,
1596 DeviceMemory<double> *x, int incx) {
1597 return DoBlasInternal(cublasDtpmv, stream, true /* = pointer_mode_host */,
1598 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1599 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1600 GpuMemoryMutable(x), incx);
1601 }
1602
1603 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1604 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1605 const DeviceMemory<std::complex<float>> &ap,
1606 DeviceMemory<std::complex<float>> *x, int incx) {
1607 return DoBlasInternal(cublasCtpmv, stream, true /* = pointer_mode_host */,
1608 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1609 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1610 GpuComplex(GpuMemoryMutable(x)), incx);
1611 }
1612
1613 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1614 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1615 const DeviceMemory<std::complex<double>> &ap,
1616 DeviceMemory<std::complex<double>> *x, int incx) {
1617 return DoBlasInternal(cublasZtpmv, stream, true /* = pointer_mode_host */,
1618 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1619 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1620 GpuComplex(GpuMemoryMutable(x)), incx);
1621 }
1622
1623 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1624 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1625 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1626 int incx) {
1627 return DoBlasInternal(cublasStpsv, stream, true /* = pointer_mode_host */,
1628 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1629 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1630 GpuMemoryMutable(x), incx);
1631 }
1632
1633 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1634 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1635 const DeviceMemory<double> &ap,
1636 DeviceMemory<double> *x, int incx) {
1637 return DoBlasInternal(cublasDtpsv, stream, true /* = pointer_mode_host */,
1638 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1639 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1640 GpuMemoryMutable(x), incx);
1641 }
1642
1643 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1644 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1645 const DeviceMemory<std::complex<float>> &ap,
1646 DeviceMemory<std::complex<float>> *x, int incx) {
1647 return DoBlasInternal(cublasCtpsv, stream, true /* = pointer_mode_host */,
1648 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1649 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1650 GpuComplex(GpuMemoryMutable(x)), incx);
1651 }
1652
1653 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1654 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1655 const DeviceMemory<std::complex<double>> &ap,
1656 DeviceMemory<std::complex<double>> *x, int incx) {
1657 return DoBlasInternal(cublasZtpsv, stream, true /* = pointer_mode_host */,
1658 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1659 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1660 GpuComplex(GpuMemoryMutable(x)), incx);
1661 }
1662
1663 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1664 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1665 const DeviceMemory<float> &a, int lda,
1666 DeviceMemory<float> *x, int incx) {
1667 return DoBlasInternal(cublasStrmv, stream, true /* = pointer_mode_host */,
1668 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1669 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1670 GpuMemoryMutable(x), incx);
1671 }
1672
1673 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1674 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1675 const DeviceMemory<double> &a, int lda,
1676 DeviceMemory<double> *x, int incx) {
1677 return DoBlasInternal(cublasDtrmv, stream, true /* = pointer_mode_host */,
1678 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1679 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1680 GpuMemoryMutable(x), incx);
1681 }
1682
1683 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1684 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1685 const DeviceMemory<std::complex<float>> &a, int lda,
1686 DeviceMemory<std::complex<float>> *x, int incx) {
1687 return DoBlasInternal(cublasCtrmv, stream, true /* = pointer_mode_host */,
1688 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1689 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1690 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1691 }
1692
1693 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1694 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1695 const DeviceMemory<std::complex<double>> &a, int lda,
1696 DeviceMemory<std::complex<double>> *x, int incx) {
1697 return DoBlasInternal(cublasZtrmv, stream, true /* = pointer_mode_host */,
1698 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1699 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1700 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1701 }
1702
1703 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1704 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1705 const DeviceMemory<float> &a, int lda,
1706 DeviceMemory<float> *x, int incx) {
1707 return DoBlasInternal(cublasStrsv, stream, true /* = pointer_mode_host */,
1708 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1709 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1710 GpuMemoryMutable(x), incx);
1711 }
1712
1713 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1714 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1715 const DeviceMemory<double> &a, int lda,
1716 DeviceMemory<double> *x, int incx) {
1717 return DoBlasInternal(cublasDtrsv, stream, true /* = pointer_mode_host */,
1718 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1719 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1720 GpuMemoryMutable(x), incx);
1721 }
1722
1723 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1724 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1725 const DeviceMemory<std::complex<float>> &a, int lda,
1726 DeviceMemory<std::complex<float>> *x, int incx) {
1727 return DoBlasInternal(cublasCtrsv, stream, true /* = pointer_mode_host */,
1728 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1729 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1730 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1731 }
1732
1733 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1734 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1735 const DeviceMemory<std::complex<double>> &a, int lda,
1736 DeviceMemory<std::complex<double>> *x, int incx) {
1737 return DoBlasInternal(cublasZtrsv, stream, true /* = pointer_mode_host */,
1738 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1739 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1740 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1741 }
1742
1743 port::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1744 blas::Transpose transb, uint64 m, uint64 n,
1745 uint64 k, blas::DataType dtype,
1746 const void *alpha, const DeviceMemoryBase &a,
1747 int lda, const DeviceMemoryBase &b, int ldb,
1748 const void *beta, DeviceMemoryBase *c,
1749 int ldc) {
1750 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
1751
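  // Note: before CUDA 11, tensor-op math must be requested explicitly for fp16
  // GEMMs; from CUDA 11 onwards fp16 uses tensor cores under the default math
  // mode, and fp32 can instead opt into TF32 tensor-op math.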
1752 #if CUDA_VERSION < 11000
1753 if (dtype == blas::DataType::kHalf) {
1754 math_type = CUBLAS_TENSOR_OP_MATH;
1755 }
1756 #else
1757 if (dtype == blas::DataType::kFloat) {
1758 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
1759 if (stream->GetCudaComputeCapability().IsAtLeast(
1760 CudaComputeCapability::AMPERE)) {
1761       // TODO(reedwm): Remove this, or downgrade it to VLOG(1), once
1762       // TensorFloat-32 is better tested.
1763 if (tensorflow::tensor_float_32_execution_enabled()) {
1764 LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
1765 "multiplication. This will only be logged "
1766 "once.";
1767 }
1768 }
1769 }
1770 #endif
1771
1772 // TODO(cheshire): Return an error instead.
1773 // TODO(cheshire): Why are these checked only for `half` and `float`?
1774 if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) {
1775 if (transa == blas::Transpose::kNoTranspose) {
1776 if (lda < static_cast<int64>(m)) {
1777         LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than m (" << m
1778                      << ") (no transpose case); precondition violation";
1779 }
1780 } else {
1781 if (lda < static_cast<int64>(k)) {
1782 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1783 << ") (transpose case); precondition violation";
1784 }
1785 }
1786 if (transb == blas::Transpose::kNoTranspose) {
1787 if (ldb < static_cast<int64>(k)) {
1788 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1789 << ") (no transpose case); precondition violation";
1790 }
1791 } else {
1792 if (ldb < static_cast<int64>(n)) {
1793         LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than n (" << n
1794                      << ") (transpose case); precondition violation";
1795 }
1796 }
1797 }
1798
1799 VLOG(1) << absl::StrFormat(
1800       "doing cuBLAS GEMM: at=%d bt=%d m=%u n=%u "
1801 "k=%u alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
1802 "c=%p ldc=%d",
1803 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1804 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1805
1806 switch (dtype) {
1807 case blas::DataType::kHalf: {
1808 #if CUDA_VERSION < 7050
1809 return port::InternalError(
1810 "fp16 sgemm is not implemented in this cuBLAS version "
1811 "(need at least CUDA 7.5)");
1812 #endif
1813
1814 return DoBlasInternalImpl(
1815 cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
1816 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1817 static_cast<const float *>(alpha), a.opaque(), SE_CUDA_DATA_HALF, lda,
1818 b.opaque(), SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
1819 c->opaque(), SE_CUDA_DATA_HALF, ldc);
1820 }
1821 #if CUDA_VERSION > 11000
1822 case blas::DataType::kBF16: {
1823 return DoBlasInternalImpl(
1824 cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
1825 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1826 static_cast<const float *>(alpha), a.opaque(), CUDA_R_16BF, lda,
1827 b.opaque(), CUDA_R_16BF, ldb, static_cast<const float *>(beta),
1828 c->opaque(), CUDA_R_16BF, ldc);
1829 }
1830 #endif
1831 case dnn::kFloat:
1832 return DoBlasInternalImpl(
1833 cublasSgemm, stream, true /* = pointer_mode_host */, math_type,
1834 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1835 static_cast<const float *>(alpha),
1836 static_cast<const float *>(a.opaque()), lda,
1837 static_cast<const float *>(b.opaque()), ldb,
1838 static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
1839 ldc);
1840 case dnn::kDouble:
1841 return DoBlasInternalImpl(
1842 cublasDgemm, stream, true /* = pointer_mode_host */, math_type,
1843 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1844 static_cast<const double *>(alpha),
1845 static_cast<const double *>(a.opaque()), lda,
1846 static_cast<const double *>(b.opaque()), ldb,
1847 static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
1848 ldc);
1849 case dnn::kComplexFloat: {
1850 GpuComplexType cb_alpha =
1851 GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
1852 GpuComplexType cb_beta =
1853 GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
1854 return DoBlasInternalImpl(
1855 cublasCgemm, stream, true /* = pointer_mode_host */, math_type,
1856 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1857 &cb_alpha, static_cast<const GpuComplexType *>(a.opaque()), lda,
1858 static_cast<const GpuComplexType *>(b.opaque()), ldb, &cb_beta,
1859 static_cast<GpuComplexType *>(c->opaque()), ldc);
1860 }
1861 case dnn::kComplexDouble: {
1862 GpuDoubleComplexType cb_alpha =
1863 GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
1864 GpuDoubleComplexType cb_beta =
1865 GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
1866 return DoBlasInternalImpl(
1867 cublasZgemm, stream, true /* = pointer_mode_host */, math_type,
1868 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
1869 &cb_alpha, static_cast<const GpuDoubleComplexType *>(a.opaque()), lda,
1870 static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, &cb_beta,
1871 static_cast<GpuDoubleComplexType *>(c->opaque()), ldc);
1872 }
1873 default:
1874 return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
1875 blas::DataTypeString(dtype)));
1876 }
1877 }
1878
1879 bool CUDABlas::DoBlasGemvWithProfiling(
1880 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1881 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1882 int incx, float beta, DeviceMemory<float> *y, int incy,
1883 blas::ProfileResult *output_profile_result) {
1884 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1885 incx, beta, y, incy,
1886 output_profile_result);
1887 }
1888
1889 bool CUDABlas::DoBlasGemvWithProfiling(
1890 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1891 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1892 int incx, double beta, DeviceMemory<double> *y, int incy,
1893 blas::ProfileResult *output_profile_result) {
1894 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1895 incx, beta, y, incy,
1896 output_profile_result);
1897 }
1898
1899 bool CUDABlas::DoBlasGemvWithProfiling(
1900 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1901 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1902 int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1903 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1904 blas::ProfileResult *output_profile_result) {
1905 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1906 incx, beta, y, incy,
1907 output_profile_result);
1908 }
1909
1910 bool CUDABlas::DoBlasGemvWithProfiling(
1911 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1912 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1913 int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1914 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1915 blas::ProfileResult *output_profile_result) {
1916 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1917 incx, beta, y, incy,
1918 output_profile_result);
1919 }
1920
1921 bool CUDABlas::DoBlasGemmWithProfiling(
1922 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1923 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1924 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1925 DeviceMemory<Eigen::half> *c, int ldc,
1926 blas::ProfileResult *output_profile_result) {
1927 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1928 lda, b, ldb, beta, c, ldc,
1929 output_profile_result);
1930 }
1931
1932 bool CUDABlas::DoBlasGemmWithProfiling(
1933 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1934 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1935 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1936 int ldc, blas::ProfileResult *output_profile_result) {
1937 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1938 lda, b, ldb, beta, c, ldc,
1939 output_profile_result);
1940 }
1941
1942 bool CUDABlas::DoBlasGemmWithProfiling(
1943 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1944 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1945 const DeviceMemory<double> &b, int ldb, double beta,
1946 DeviceMemory<double> *c, int ldc,
1947 blas::ProfileResult *output_profile_result) {
1948 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1949 lda, b, ldb, beta, c, ldc,
1950 output_profile_result);
1951 }
1952
1953 bool CUDABlas::DoBlasGemmWithProfiling(
1954 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1955 uint64 n, uint64 k, std::complex<float> alpha,
1956 const DeviceMemory<std::complex<float>> &a, int lda,
1957 const DeviceMemory<std::complex<float>> &b, int ldb,
1958 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1959 blas::ProfileResult *output_profile_result) {
1960 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1961 lda, b, ldb, beta, c, ldc,
1962 output_profile_result);
1963 }
1964
1965 bool CUDABlas::DoBlasGemmWithProfiling(
1966 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1967 uint64 n, uint64 k, std::complex<double> alpha,
1968 const DeviceMemory<std::complex<double>> &a, int lda,
1969 const DeviceMemory<std::complex<double>> &b, int ldb,
1970 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1971 blas::ProfileResult *output_profile_result) {
1972 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1973 lda, b, ldb, beta, c, ldc,
1974 output_profile_result);
1975 }
1976
1977 template <typename T>
1978 bool CUDABlas::DoBlasGemvWithProfilingImpl(
1979 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1980 const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1981 const T &beta, DeviceMemory<T> *y, int incy,
1982 blas::ProfileResult *output_profile_result) {
1983 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1984 if (output_profile_result != nullptr) {
1985 timer.reset(new GpuTimer(parent_));
1986 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1987 return false;
1988 }
1989 }
1990
1991   // Call blasGemv
1992 bool result =
1993 DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
1994
1995 if (timer != nullptr && result) {
1996 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
1997 // state.
1998 if (!timer->Stop(AsGpuStream(stream))) {
1999 return false;
2000 }
2001 output_profile_result->set_is_valid(true);
2002 output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
2003 output_profile_result->set_elapsed_time_in_ms(
2004 timer->GetElapsedMilliseconds());
2005 }
2006 return result;
2007 }
2008
2009 template <typename T, typename ParamType>
2010 bool CUDABlas::DoBlasGemmWithProfilingImpl(
2011 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2012 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
2013 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
2014 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
2015 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2016 if (output_profile_result != nullptr) {
2017 timer.reset(new GpuTimer(parent_));
2018 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2019 return false;
2020 }
2021 }
2022
2023 // Call blasGemm
2024 bool result =
2025 DoBlasGemm(stream, transa, transb, m, n, k, blas::ToDataType<T>::value,
2026 &alpha, a, lda, b, ldb, &beta, c, ldc)
2027 .ok();
2028
2029 if (timer != nullptr && result) {
2030 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2031 // state.
2032 if (!timer->Stop(AsGpuStream(stream))) {
2033 return false;
2034 }
2035 output_profile_result->set_is_valid(true);
2036 output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
2037 output_profile_result->set_elapsed_time_in_ms(
2038 timer->GetElapsedMilliseconds());
2039 }
2040 return result;
2041 }
2042
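// Returns true if `algo` names one of the *_TENSOR_OP cuBLAS GEMM algorithms.
// Those enumerators were introduced in CUDA 9.0 and are numerically greater
// than or equal to CUBLAS_GEMM_DEFAULT_TENSOR_OP.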
2043 static bool UsesTensorOps(blas::AlgorithmType algo) {
2044 #if CUDA_VERSION >= 9000
2045 cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
2046 return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
2047 #else
2048 return false;
2049 #endif
2050 }
2051
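// Checks whether `algorithm` is usable for the given input types on this
// device and returns the cublasMath_t mode to pass to cublasGemmEx-style
// calls.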
2052 static port::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
2053     Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
2054     blas::DataType type_b, uint64 m, uint64 n, uint64 k) {
2055 if (type_a != type_b) {
2056 return port::InternalError("Types of inputs mismatch");
2057 }
2058
2059 // GPUs < sm_50 don't support cublasGemmEx.
2060 CudaComputeCapability cc = stream->GetCudaComputeCapability();
2061 if (cc.major < 5) {
2062 return port::InternalError(absl::StrCat(
2063 "sm_", cc.major, " does not support explicit gemm algorithms."));
2064 }
2065
2066 bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
2067 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2068 if (algo_uses_tensor_ops) {
2069 if (cc.major < 7) {
2070 return port::InternalError(absl::StrCat(
2071 "Algorithm ", algorithm,
2072 " uses tensor ops, but tensor ops are not available in sm", cc.major,
2073 "X devices."));
2074 } else if (type_a == blas::DataType::kFloat) {
2075 #if CUDA_VERSION < 11000
2076 return port::InternalError(absl::StrCat(
2077 "Algorithm ", algorithm,
2078 " uses tensor ops, but tensor ops are not available for fp32"));
2079 #else
2080 if (cc.major < 8) {
2081 return port::InternalError(absl::StrCat(
2082 "Algorithm ", algorithm,
2083 " uses tensor ops, but tensor ops are not available in sm",
2084 cc.major, "X devices for float input types."));
2085 } else if (!tensorflow::tensor_float_32_execution_enabled()) {
2086 return port::InternalError(absl::StrCat(
2087 "Algorithm ", algorithm,
2088 " uses tensor ops, but tensor ops are disabled for fp32 inputs"));
2089 }
2090 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2091 #endif
2092 } else if (type_a == blas::DataType::kHalf) {
2093 #if CUDA_VERSION < 11000
2094 math_type = CUBLAS_TENSOR_OP_MATH;
2095 #endif
2096 } else {
2097 return port::InternalError(
2098 absl::StrCat("Algorithm ", algorithm,
2099 " uses tensor ops which are not supported for input"));
2100 }
2101 }
2102
2103   // Return an error if we might be hitting a cuBLAS bug that produces the
2104   // wrong result. See nvbugs/2156201, b/79126339.
2105 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
2106   if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
2107       std::max({m, n, k}) >= 2097153 && cc.major < 7) {
2108 return port::InternalError(
2109 "DoBlasGemmWithAlgorithm returning false to work around cudnn "
2110 "<9.2 bug with m, n, or k >= 2097153. See b/79126339.");
2111 }
2112 #endif
2113 return math_type;
2114 }
2115
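// Returns a started GpuTimer if profiling was requested (i.e.
// output_profile_result is non-null); otherwise returns a null timer.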
2116 static port::StatusOr<std::unique_ptr<GpuTimer, GpuTimerDeleter>>
2117 StartGpuTimerForProfile(Stream *stream, GpuExecutor *executor,
2118 blas::ProfileResult *output_profile_result) {
2119 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2120 if (output_profile_result) {
2121 timer.reset(new GpuTimer(executor));
2122 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2123 return port::InternalError(
2124 "output_profile_result given, but unable to create a GpuTimer");
2125 }
2126 }
2127 return timer;
2128 }
2129
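// If `timer` is non-null, stops it and records the algorithm and elapsed time
// in output_profile_result.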
2130 static port::Status PopulateProfileFromTimer(
2131 GpuTimer *timer, blas::AlgorithmType algorithm,
2132 blas::ProfileResult *output_profile_result, Stream *stream) {
2133 if (timer) {
2134 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2135 // state.
2136 if (!timer->Stop(AsGpuStream(stream))) {
2137 return port::InternalError("unable to stop GpuTimer.");
2138 }
2139 output_profile_result->set_is_valid(true);
2140 output_profile_result->set_algorithm(algorithm);
2141 output_profile_result->set_elapsed_time_in_ms(
2142 timer->GetElapsedMilliseconds());
2143 }
2144 return port::Status::OK();
2145 }
2146
2147 port::Status CUDABlas::DoBlasGemmWithAlgorithm(
2148 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2149 uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
2150 blas::DataType type_a, int lda, const DeviceMemoryBase &b,
2151 blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
2152 blas::DataType type_c, int ldc, blas::ComputationType computation_type,
2153 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2154   TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
2155                       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, m, n, k));
2156
2157 TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
2158 stream, parent_, output_profile_result));
2159
2160 // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
2161 // we do the following compile-time check on the default value:
2162 static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
2163
2164 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2165 AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
2166 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, alpha,
2167 a.opaque(), GetCUDADataType(type_a), lda, b.opaque(),
2168 GetCUDADataType(type_b), ldb, beta, c->opaque(), GetCUDADataType(type_c),
2169 ldc, CUDAComputationType(computation_type),
2170 static_cast<cublasGemmAlgo_t>(algorithm)));
2171 TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
2172 output_profile_result, stream));
2173 return port::Status::OK();
2174 }
2175
2176 port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
2177 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2178 uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
2179 blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
2180 blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
2181 DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
2182 int batch_count, blas::ComputationType computation_type,
2183 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2184   TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
2185                       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, m, n, k));
2186 TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
2187 stream, parent_, output_profile_result));
2188
2189 cudaDataType_t cuda_in_type = GetCUDADataType(type_a);
2190
2191 #if CUDA_VERSION >= 11000
2192   // Work around a CUDA bug where bf16 batched GEMM is erroneously marked as
2193   // unsupported on pre-sm_70 GPUs (e.g. Pascal) by manually unbatching it.
2194 if (cuda_in_type == CUDA_R_16BF &&
2195 !stream->GetCudaComputeCapability().IsAtLeast(7)) {
2196 for (int batch = 0; batch < batch_count; ++batch) {
2197 const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2198 static_cast<const Eigen::bfloat16 *>(a.opaque()) + batch * stride_a);
2199 const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2200 static_cast<const Eigen::bfloat16 *>(b.opaque()) + batch * stride_b);
2201 auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
2202 static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
2203 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2204 AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true,
2205 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2206 k, static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
2207 b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
2208 c_matrix, CUDA_R_16BF, ldc, CUDAComputationType(computation_type),
2209 static_cast<cublasGemmAlgo_t>(algorithm)));
2210 }
2211 TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
2212 output_profile_result, stream));
2213 return port::Status::OK();
2214 }
2215 #endif
2216
2217 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2218 AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,
2219 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2220 alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type,
2221 ldb, stride_b, beta, c->opaque(), GetCUDADataType(type_c), ldc, stride_c,
2222 batch_count, CUDAComputationType(computation_type),
2223 static_cast<cublasGemmAlgo_t>(algorithm)));
2224 TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
2225 output_profile_result, stream));
2226 return port::Status::OK();
2227 }
2228
2229 bool CUDABlas::GetBlasGemmAlgorithms(
2230 std::vector<blas::AlgorithmType> *out_algorithms) {
2231 // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
2232 // were first introduced in CUDA 8.
2233 //
2234   // Note that when the CUDA version and compute capability are not
2235   // sufficient, we still return the out_algorithms. The caller needs to make
2236   // sure that in this case the returned algorithms are not used.
2237 *out_algorithms = {
2238 CUBLAS_GEMM_DFALT,
2239 CUBLAS_GEMM_ALGO0,
2240 CUBLAS_GEMM_ALGO1,
2241 CUBLAS_GEMM_ALGO2,
2242 CUBLAS_GEMM_ALGO3,
2243 CUBLAS_GEMM_ALGO4,
2244 CUBLAS_GEMM_ALGO5,
2245 CUBLAS_GEMM_ALGO6,
2246 CUBLAS_GEMM_ALGO7,
2247 #if CUDA_VERSION >= 9000
2248 CUBLAS_GEMM_ALGO8,
2249 CUBLAS_GEMM_ALGO9,
2250 CUBLAS_GEMM_ALGO10,
2251 CUBLAS_GEMM_ALGO11,
2252 CUBLAS_GEMM_ALGO12,
2253 CUBLAS_GEMM_ALGO13,
2254 CUBLAS_GEMM_ALGO14,
2255 CUBLAS_GEMM_ALGO15,
2256 CUBLAS_GEMM_ALGO16,
2257 CUBLAS_GEMM_ALGO17,
2258 CUBLAS_GEMM_DFALT_TENSOR_OP,
2259 CUBLAS_GEMM_ALGO0_TENSOR_OP,
2260 CUBLAS_GEMM_ALGO1_TENSOR_OP,
2261 CUBLAS_GEMM_ALGO2_TENSOR_OP,
2262 CUBLAS_GEMM_ALGO3_TENSOR_OP,
2263 CUBLAS_GEMM_ALGO4_TENSOR_OP,
2264 #endif
2265 #if CUDA_VERSION >= 9020
2266 CUBLAS_GEMM_ALGO18,
2267 CUBLAS_GEMM_ALGO19,
2268 CUBLAS_GEMM_ALGO20,
2269 CUBLAS_GEMM_ALGO21,
2270 CUBLAS_GEMM_ALGO22,
2271 CUBLAS_GEMM_ALGO23,
2272 CUBLAS_GEMM_ALGO5_TENSOR_OP,
2273 CUBLAS_GEMM_ALGO6_TENSOR_OP,
2274 CUBLAS_GEMM_ALGO7_TENSOR_OP,
2275 CUBLAS_GEMM_ALGO8_TENSOR_OP,
2276 CUBLAS_GEMM_ALGO9_TENSOR_OP,
2277 CUBLAS_GEMM_ALGO10_TENSOR_OP,
2278 CUBLAS_GEMM_ALGO11_TENSOR_OP,
2279 CUBLAS_GEMM_ALGO12_TENSOR_OP,
2280 CUBLAS_GEMM_ALGO13_TENSOR_OP,
2281 CUBLAS_GEMM_ALGO14_TENSOR_OP,
2282 CUBLAS_GEMM_ALGO15_TENSOR_OP,
2283 #endif
2284 };
2285 return true;
2286 }
2287
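// Maps Eigen::half to float for the batched GEMM helpers below: fp16 batches
// are dispatched either through cublasGemmBatchedEx (whose pointer arrays are
// passed as void**) or through a per-batch DoBlasGemm loop, so the pointer
// arrays never need a __half element type.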
2288 template <typename T>
2289 struct HalfAsFloat {
2290 typedef T type;
2291 };
2292
2293 template <>
2294 struct HalfAsFloat<Eigen::half> {
2295 typedef float type;
2296 };
2297
2298 namespace {
2299 // pass-through for non-complex types that don't need conversion to
2300 // cublas-specific type.
2301 template <typename T>
2302 T inline GpuComplexValue(T v) {
2303 return v;
2304 }
2305 } // namespace
2306
2307 template <typename T, typename Scalar, typename FuncT>
2308 port::Status CUDABlas::DoBlasGemmBatchedInternal(
2309 FuncT cublas_func, Stream *stream, blas::Transpose transa,
2310 blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
2311 const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
2312 const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
2313 Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
2314 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2315 std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
2316 for (int i = 0; i < batch_count; ++i) {
2317 a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
2318 b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
2319 c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
2320 }
2321
2322 typedef typename HalfAsFloat<typename GpuComplexT<T>::type>::type CUDA_T;
2323
2324 const size_t size = batch_count * sizeof(CUDA_T *);
2325
2326 // Device-side copy of pointers to matrices.
2327 DeviceMemory<CUDA_T *> a;
2328 DeviceMemory<CUDA_T *> b;
2329 DeviceMemory<CUDA_T *> c;
2330
2331 // If temporary space is allocated for device-side copies of pointers to
2332 // matrices, that temporary space should not be freed until this function
2333 // returns. Although the values for these unique_ptrs are not set here, they
2334 // are declared at this scope so they will be destroyed when the function
2335 // returns.
2336 //
2337 // If a scratch allocator is provided, these pointers will not be used at all.
2338 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
2339 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
2340 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;
2341
2342 // Decide how to allocate device-side copy of pointers to matrices based on
2343 // whether a scratch allocator was passed.
2344 if (scratch_allocator != nullptr) {
2345 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
2346 scratch_allocator->AllocateBytes(size));
2347 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
2348 scratch_allocator->AllocateBytes(size));
2349 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
2350 scratch_allocator->AllocateBytes(size));
2351 a = DeviceMemory<CUDA_T *>(a_bytes);
2352 b = DeviceMemory<CUDA_T *>(b_bytes);
2353 c = DeviceMemory<CUDA_T *>(c_bytes);
2354 } else {
2355 SE_ASSIGN_OR_RETURN(a_temporary,
2356 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2357 SE_ASSIGN_OR_RETURN(b_temporary,
2358 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2359 SE_ASSIGN_OR_RETURN(c_temporary,
2360 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2361 a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
2362 b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
2363 c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
2364 }
2365
2366 if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
2367 !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
2368 !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
2369 return port::Status(port::error::INTERNAL,
2370 "failed to copy memory from host to device in "
2371 "CUDABlas::DoBlasGemmBatched");
2372 }
2373
2374 cudaDataType_t data_type = CUDADataType<T>::type;
2375
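  // On CUDA >= 9.1 and sm_50+, use cublasGemmBatchedEx so that one code path
  // covers fp16/fp32/fp64/complex and can select tensor-op algorithms where
  // available.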
2376 #if CUDA_VERSION >= 9010
2377 if (stream->GetCudaComputeCapability().IsAtLeast(5)) {
2378 cublasMath_t math_type;
2379 cublasGemmAlgo_t algo;
2380 if (data_type == CUDA_R_16F) {
2381 #if CUDA_VERSION < 11000
2382 math_type = CUBLAS_TENSOR_OP_MATH;
2383 #else
2384 math_type = CUBLAS_DEFAULT_MATH;
2385 #endif
2386 algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
2387 #if CUBLAS_VER_MAJOR >= 11
2388 } else if (data_type == CUDA_R_32F) {
2389       // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
2390 // if TensorFloat-32 is disabled.
2391 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2392 algo = tensorflow::tensor_float_32_execution_enabled()
2393 ? CUBLAS_GEMM_DFALT_TENSOR_OP
2394 : CUBLAS_GEMM_DFALT;
2395 #endif
2396 } else {
2397 math_type = CUBLAS_DEFAULT_MATH;
2398 algo = CUBLAS_GEMM_DFALT;
2399 }
2400 cudaDataType_t compute_type =
2401 (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
2402 const void **a_void_ptrs = reinterpret_cast<const void **>(
2403 const_cast<const CUDA_T **>(GpuMemory(a)));
2404 const void **b_void_ptrs = reinterpret_cast<const void **>(
2405 const_cast<const CUDA_T **>(GpuMemory(b)));
2406 void **c_void_ptrs =
2407 reinterpret_cast<void **>(const_cast<CUDA_T **>(GpuMemory(c)));
2408 return DoBlasInternalImpl(
2409 AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
2410 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2411 k, &alpha, a_void_ptrs, data_type, lda, b_void_ptrs, data_type, ldb,
2412 &beta, c_void_ptrs, data_type, ldc, batch_count, compute_type, algo);
2413 }
2414 #endif
2415   // Either CUDA_VERSION < 9.1 or SM < 5.0: fall back to the typed batched call (or a per-batch loop for fp16).
2416 if (data_type != CUDA_R_16F) {
2417 auto cb_alpha = GpuComplexValue(alpha);
2418 auto cb_beta = GpuComplexValue(beta);
2419 bool ok = DoBlasInternal(
2420 cublas_func, stream, true /* = pointer_mode_host */,
2421 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2422 GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
2423 const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2424 const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
2425 if (ok) {
2426 return port::Status::OK();
2427 }
2428 return port::Status(port::error::INTERNAL,
2429 "failed BLAS call, see log for details");
2430 } else {
2431 // Fall back to a loop for fp16
2432 for (int b = 0; b < batch_count; ++b) {
2433 const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
2434 const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
2435 DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
2436 TF_RETURN_IF_ERROR(DoBlasGemm(
2437 stream, transa, transb, m, n, k, blas::ToDataType<T>::value, &alpha,
2438 a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc));
2439 }
2440 return port::Status::OK();
2441 }
2442 }
2443
2444 bool CUDABlas::DoBlasGemmBatched(
2445 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2446 uint64 n, uint64 k, float alpha,
2447 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
2448 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
2449 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
2450 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2451 // Note: The func passed here (cublasSgemmBatched) is not actually called,
2452 // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
2453 port::Status status = DoBlasGemmBatchedInternal(
2454 cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2455 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2456 if (!status.ok()) {
2457 LOG(ERROR) << status;
2458 }
2459 return status.ok();
2460 }
2461
2462 bool CUDABlas::DoBlasGemmBatched(
2463 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2464 uint64 n, uint64 k, float alpha,
2465 const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
2466 const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
2467 const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
2468 int batch_count, ScratchAllocator *scratch_allocator) {
2469 port::Status status = DoBlasGemmBatchedInternal(
2470 cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2471 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2472 if (!status.ok()) {
2473 LOG(ERROR) << status;
2474 }
2475 return status.ok();
2476 }
2477
2478 bool CUDABlas::DoBlasGemmBatched(
2479 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2480 uint64 n, uint64 k, double alpha,
2481 const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
2482 const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
2483 double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
2484 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2485 port::Status status = DoBlasGemmBatchedInternal(
2486 cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2487 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2488 if (!status.ok()) {
2489 LOG(ERROR) << status;
2490 }
2491 return status.ok();
2492 }
2493
2494 bool CUDABlas::DoBlasGemmBatched(
2495 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2496 uint64 n, uint64 k, std::complex<float> alpha,
2497 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
2498 int lda,
2499 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
2500 int ldb, std::complex<float> beta,
2501 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
2502 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2503 port::Status status = DoBlasGemmBatchedInternal(
2504 cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2505 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2506 if (!status.ok()) {
2507 LOG(ERROR) << status;
2508 }
2509 return status.ok();
2510 }
2511
2512 bool CUDABlas::DoBlasGemmBatched(
2513 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2514 uint64 n, uint64 k, std::complex<double> alpha,
2515 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
2516 int lda,
2517 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
2518 int ldb, std::complex<double> beta,
2519 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
2520 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2521 port::Status status = DoBlasGemmBatchedInternal(
2522 cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2523 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2524 if (!status.ok()) {
2525 LOG(ERROR) << status;
2526 }
2527 return status.ok();
2528 }
2529
2530 port::Status CUDABlas::DoBlasGemmStridedBatched(
2531 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2532 uint64 n, uint64 k, blas::DataType dtype, const void *alpha,
2533 const DeviceMemoryBase &a, int lda, int64_t stride_a,
2534 const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
2535 DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) {
2536 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2537 #if CUDA_VERSION < 11000
2538 if (dtype == dnn::kHalf) {
2539 math_type = CUBLAS_TENSOR_OP_MATH;
2540 }
2541 #else
2542 if (dtype == dnn::kFloat) {
2543 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2544 }
2545 #endif
2546
2547 switch (dtype) {
2548 #if CUDA_VERSION >= 11000
2549 case dnn::kBF16: {
2550 CudaComputeCapability cc = stream->GetCudaComputeCapability();
2551 if (cc.IsAtLeast(7)) {
2552 cublasGemmAlgo_t algo =
2553 (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
2554 return DoBlasInternalImpl(
2555 AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
2556 true /* = pointer_mode_host */, math_type,
2557 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2558 alpha, a.opaque(), CUDA_R_16BF, lda, stride_a, b.opaque(),
2559 CUDA_R_16BF, ldb, stride_b, beta, c->opaque(), CUDA_R_16BF, ldc,
2560 stride_c, batch_count,
2561 /*compute_type=*/CUDA_R_32F, algo);
2562 }
2563 // Fall back to a loop.
2564 for (int batch = 0; batch < batch_count; ++batch) {
2565 const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2566 static_cast<const Eigen::bfloat16 *>(a.opaque()) +
2567 batch * stride_a);
2568 const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
2569 static_cast<const Eigen::bfloat16 *>(b.opaque()) +
2570 batch * stride_b);
2571 auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
2572 static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
2573 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2574 cublasSgemmEx, stream, true /* = pointer_mode_host */,
2575 CUBLAS_DEFAULT_MATH, CUDABlasTranspose(transa),
2576 CUDABlasTranspose(transb), m, n, k,
2577 static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
2578 b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
2579 c_matrix, CUDA_R_16BF, ldc));
2580 }
2581 return port::Status::OK();
2582 }
2583 #endif
2584 case dnn::kHalf: {
2585 #if CUDA_VERSION >= 9010
2586 CudaComputeCapability cc = stream->GetCudaComputeCapability();
2587 if (cc.major >= 5) {
2588 cublasGemmAlgo_t algo =
2589 (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
2590 return DoBlasInternalImpl(
2591 AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
2592 true /* = pointer_mode_host */, math_type,
2593 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2594 alpha, a.opaque(), CUDA_R_16F, lda, stride_a, b.opaque(),
2595 CUDA_R_16F, ldb, stride_b, beta, c->opaque(), CUDA_R_16F, ldc,
2596 stride_c, batch_count, CUDA_R_32F, algo);
2597 }
2598 #endif
2599 // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
2600 for (int batch = 0; batch < batch_count; ++batch) {
2601 const auto *a_matrix = reinterpret_cast<const __half *>(
2602 static_cast<const Eigen::half *>(a.opaque()) + batch * stride_a);
2603 const auto *b_matrix = reinterpret_cast<const __half *>(
2604 static_cast<const Eigen::half *>(b.opaque()) + batch * stride_b);
2605 auto *c_matrix = reinterpret_cast<__half *>(
2606 static_cast<Eigen::half *>(c->opaque()) + batch * stride_c);
2607 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
2608 cublasSgemmEx, stream, true /* = pointer_mode_host */,
2609 CUBLAS_DEFAULT_MATH, CUDABlasTranspose(transa),
2610 CUDABlasTranspose(transb), m, n, k,
2611 static_cast<const float *>(alpha), a_matrix, SE_CUDA_DATA_HALF, lda,
2612 b_matrix, SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
2613 c_matrix, SE_CUDA_DATA_HALF, ldc));
2614 }
2615 return port::Status::OK();
2616 }
2617 case dnn::kFloat: {
2618 return DoBlasInternalImpl(
2619 cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
2620 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2621 k, static_cast<const float *>(alpha),
2622 static_cast<const float *>(a.opaque()), lda, stride_a,
2623 static_cast<const float *>(b.opaque()), ldb, stride_b,
2624 static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
2625 ldc, stride_c, batch_count);
2626 }
2627 case dnn::kDouble:
2628 return DoBlasInternalImpl(
2629 cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
2630 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2631 k, static_cast<const double *>(alpha),
2632 static_cast<const double *>(a.opaque()), lda, stride_a,
2633 static_cast<const double *>(b.opaque()), ldb, stride_b,
2634 static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
2635 ldc, stride_c, batch_count);
2636 case dnn::kComplexFloat: {
2637 GpuComplexType cb_alpha =
2638 GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
2639 GpuComplexType cb_beta =
2640 GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
2641 return DoBlasInternalImpl(
2642 cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
2643 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2644 k, GpuComplex(&cb_alpha),
2645 static_cast<const GpuComplexType *>(a.opaque()), lda, stride_a,
2646 static_cast<const GpuComplexType *>(b.opaque()), ldb, stride_b,
2647 GpuComplex(&cb_beta), static_cast<GpuComplexType *>(c->opaque()), ldc,
2648 stride_c, batch_count);
2649 }
2650 case dnn::kComplexDouble: {
2651 GpuDoubleComplexType cb_alpha =
2652 GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
2653 GpuDoubleComplexType cb_beta =
2654 GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
2655 return DoBlasInternalImpl(
2656 cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
2657 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n,
2658 k, GpuComplex(&cb_alpha),
2659 static_cast<const GpuDoubleComplexType *>(a.opaque()), lda, stride_a,
2660 static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, stride_b,
2661 GpuComplex(&cb_beta),
2662 static_cast<GpuDoubleComplexType *>(c->opaque()), ldc, stride_c,
2663 batch_count);
2664 }
2665 default:
2666 return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
2667 blas::DataTypeString(dtype)));
2668 }
2669 }
2670
2671 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
2672 blas::UpperLower uplo, uint64 m, uint64 n,
2673 std::complex<float> alpha,
2674 const DeviceMemory<std::complex<float>> &a, int lda,
2675 const DeviceMemory<std::complex<float>> &b, int ldb,
2676 std::complex<float> beta,
2677 DeviceMemory<std::complex<float>> *c, int ldc) {
2678 auto cb_alpha = GpuComplexValue(alpha);
2679 auto cb_beta = GpuComplexValue(beta);
2680 return DoBlasInternal(cublasChemm, stream, true /* = pointer_mode_host */,
2681 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2682 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2683 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2684 GpuComplex(GpuMemoryMutable(c)), ldc);
2685 }
2686
2687 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
2688 blas::UpperLower uplo, uint64 m, uint64 n,
2689 std::complex<double> alpha,
2690 const DeviceMemory<std::complex<double>> &a, int lda,
2691 const DeviceMemory<std::complex<double>> &b, int ldb,
2692 std::complex<double> beta,
2693 DeviceMemory<std::complex<double>> *c, int ldc) {
2694 auto cb_alpha = GpuComplexValue(alpha);
2695 auto cb_beta = GpuComplexValue(beta);
2696 return DoBlasInternal(cublasZhemm, stream, true /* = pointer_mode_host */,
2697 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2698 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2699 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2700 GpuComplex(GpuMemoryMutable(c)), ldc);
2701 }
2702
2703 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2704 blas::Transpose trans, uint64 n, uint64 k,
2705 float alpha,
2706 const DeviceMemory<std::complex<float>> &a, int lda,
2707 float beta, DeviceMemory<std::complex<float>> *c,
2708 int ldc) {
2709 return DoBlasInternal(cublasCherk, stream, true /* = pointer_mode_host */,
2710 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2711 k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
2712 GpuComplex(GpuMemoryMutable(c)), ldc);
2713 }
2714
2715 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2716 blas::Transpose trans, uint64 n, uint64 k,
2717 double alpha,
2718 const DeviceMemory<std::complex<double>> &a, int lda,
2719 double beta, DeviceMemory<std::complex<double>> *c,
2720 int ldc) {
2721 return DoBlasInternal(cublasZherk, stream, true /* = pointer_mode_host */,
2722 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2723 k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
2724 GpuComplex(GpuMemoryMutable(c)), ldc);
2725 }
2726
2727 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2728 blas::Transpose trans, uint64 n, uint64 k,
2729 std::complex<float> alpha,
2730 const DeviceMemory<std::complex<float>> &a, int lda,
2731 const DeviceMemory<std::complex<float>> &b, int ldb,
2732 float beta, DeviceMemory<std::complex<float>> *c,
2733 int ldc) {
2734 auto cb_alpha = GpuComplexValue(alpha);
2735 return DoBlasInternal(cublasCher2k, stream, true /* = pointer_mode_host */,
2736 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2737 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2738 GpuComplex(GpuMemory(b)), ldb, &beta,
2739 GpuComplex(GpuMemoryMutable(c)), ldc);
2740 }
2741
2742 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2743 blas::Transpose trans, uint64 n, uint64 k,
2744 std::complex<double> alpha,
2745 const DeviceMemory<std::complex<double>> &a, int lda,
2746 const DeviceMemory<std::complex<double>> &b, int ldb,
2747 double beta, DeviceMemory<std::complex<double>> *c,
2748 int ldc) {
2749 auto cb_alpha = GpuComplexValue(alpha);
2750 return DoBlasInternal(cublasZher2k, stream, true /* = pointer_mode_host */,
2751 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2752 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2753 GpuComplex(GpuMemory(b)), ldb, &beta,
2754 GpuComplex(GpuMemoryMutable(c)), ldc);
2755 }
2756
2757 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2758 blas::UpperLower uplo, uint64 m, uint64 n,
2759 float alpha, const DeviceMemory<float> &a, int lda,
2760 const DeviceMemory<float> &b, int ldb, float beta,
2761 DeviceMemory<float> *c, int ldc) {
2762 return DoBlasInternal(cublasSsymm, stream, true /* = pointer_mode_host */,
2763 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2764 &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2765 GpuMemoryMutable(c), ldc);
2766 }
2767
2768 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2769 blas::UpperLower uplo, uint64 m, uint64 n,
2770 double alpha, const DeviceMemory<double> &a, int lda,
2771 const DeviceMemory<double> &b, int ldb, double beta,
2772 DeviceMemory<double> *c, int ldc) {
2773 return DoBlasInternal(cublasDsymm, stream, true /* = pointer_mode_host */,
2774 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2775 &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2776 GpuMemoryMutable(c), ldc);
2777 }
2778
2779 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2780 blas::UpperLower uplo, uint64 m, uint64 n,
2781 std::complex<float> alpha,
2782 const DeviceMemory<std::complex<float>> &a, int lda,
2783 const DeviceMemory<std::complex<float>> &b, int ldb,
2784 std::complex<float> beta,
2785 DeviceMemory<std::complex<float>> *c, int ldc) {
2786 auto cb_alpha = GpuComplexValue(alpha);
2787 auto cb_beta = GpuComplexValue(beta);
2788 return DoBlasInternal(cublasCsymm, stream, true /* = pointer_mode_host */,
2789 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2790 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2791 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2792 GpuComplex(GpuMemoryMutable(c)), ldc);
2793 }
2794
2795 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2796 blas::UpperLower uplo, uint64 m, uint64 n,
2797 std::complex<double> alpha,
2798 const DeviceMemory<std::complex<double>> &a, int lda,
2799 const DeviceMemory<std::complex<double>> &b, int ldb,
2800 std::complex<double> beta,
2801 DeviceMemory<std::complex<double>> *c, int ldc) {
2802 auto cb_alpha = GpuComplexValue(alpha);
2803 auto cb_beta = GpuComplexValue(beta);
2804 return DoBlasInternal(cublasZsymm, stream, true /* = pointer_mode_host */,
2805 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2806 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2807 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2808 GpuComplex(GpuMemoryMutable(c)), ldc);
2809 }
2810
2811 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2812 blas::Transpose trans, uint64 n, uint64 k,
2813 float alpha, const DeviceMemory<float> &a, int lda,
2814 float beta, DeviceMemory<float> *c, int ldc) {
2815 return DoBlasInternal(cublasSsyrk, stream, true /* = pointer_mode_host */,
2816 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2817 k, &alpha, GpuMemory(a), lda, &beta,
2818 GpuMemoryMutable(c), ldc);
2819 }
2820
2821 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2822 blas::Transpose trans, uint64 n, uint64 k,
2823 double alpha, const DeviceMemory<double> &a, int lda,
2824 double beta, DeviceMemory<double> *c, int ldc) {
2825 return DoBlasInternal(cublasDsyrk, stream, true /* = pointer_mode_host */,
2826 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2827 k, &alpha, GpuMemory(a), lda, &beta,
2828 GpuMemoryMutable(c), ldc);
2829 }
2830
2831 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2832 blas::Transpose trans, uint64 n, uint64 k,
2833 std::complex<float> alpha,
2834 const DeviceMemory<std::complex<float>> &a, int lda,
2835 std::complex<float> beta,
2836 DeviceMemory<std::complex<float>> *c, int ldc) {
2837 auto cb_alpha = GpuComplexValue(alpha);
2838 auto cb_beta = GpuComplexValue(beta);
2839 return DoBlasInternal(cublasCsyrk, stream, true /* = pointer_mode_host */,
2840 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2841 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2842 GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
2843 ldc);
2844 }
2845
2846 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2847 blas::Transpose trans, uint64 n, uint64 k,
2848 std::complex<double> alpha,
2849 const DeviceMemory<std::complex<double>> &a, int lda,
2850 std::complex<double> beta,
2851 DeviceMemory<std::complex<double>> *c, int ldc) {
2852 auto cb_alpha = GpuComplexValue(alpha);
2853 auto cb_beta = GpuComplexValue(beta);
2854 return DoBlasInternal(cublasZsyrk, stream, true /* = pointer_mode_host */,
2855 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2856 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2857 GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
2858 ldc);
2859 }
2860
2861 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2862 blas::Transpose trans, uint64 n, uint64 k,
2863 float alpha, const DeviceMemory<float> &a, int lda,
2864 const DeviceMemory<float> &b, int ldb, float beta,
2865 DeviceMemory<float> *c, int ldc) {
2866 return DoBlasInternal(cublasSsyr2k, stream, true /* = pointer_mode_host */,
2867 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2868 k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2869 GpuMemoryMutable(c), ldc);
2870 }
2871
2872 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2873 blas::Transpose trans, uint64 n, uint64 k,
2874 double alpha, const DeviceMemory<double> &a, int lda,
2875 const DeviceMemory<double> &b, int ldb, double beta,
2876 DeviceMemory<double> *c, int ldc) {
2877 return DoBlasInternal(cublasDsyr2k, stream, true /* = pointer_mode_host */,
2878 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2879 k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2880 GpuMemoryMutable(c), ldc);
2881 }
2882
2883 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2884 blas::Transpose trans, uint64 n, uint64 k,
2885 std::complex<float> alpha,
2886 const DeviceMemory<std::complex<float>> &a, int lda,
2887 const DeviceMemory<std::complex<float>> &b, int ldb,
2888 std::complex<float> beta,
2889 DeviceMemory<std::complex<float>> *c, int ldc) {
2890 auto cb_alpha = GpuComplexValue(alpha);
2891 auto cb_beta = GpuComplexValue(beta);
2892 return DoBlasInternal(cublasCsyr2k, stream, true /* = pointer_mode_host */,
2893 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2894 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2895 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2896 GpuComplex(GpuMemoryMutable(c)), ldc);
2897 }
2898
2899 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2900 blas::Transpose trans, uint64 n, uint64 k,
2901 std::complex<double> alpha,
2902 const DeviceMemory<std::complex<double>> &a, int lda,
2903 const DeviceMemory<std::complex<double>> &b, int ldb,
2904 std::complex<double> beta,
2905 DeviceMemory<std::complex<double>> *c, int ldc) {
2906 auto cb_alpha = GpuComplexValue(alpha);
2907 auto cb_beta = GpuComplexValue(beta);
2908 return DoBlasInternal(cublasZsyr2k, stream, true /* = pointer_mode_host */,
2909 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2910 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2911 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2912 GpuComplex(GpuMemoryMutable(c)), ldc);
2913 }
2914
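// Note on the TRMM wrappers below: the cuBLAS v2 ?trmm API is out-of-place
// (it takes separate B and C arguments and computes C = alpha * op(A) * B), so
// the output buffer `b` is passed both as the B input and as the C output to
// reproduce the in-place semantics of classic BLAS TRMM.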
2915 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2916 blas::UpperLower uplo, blas::Transpose transa,
2917 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2918 const DeviceMemory<float> &a, int lda,
2919 DeviceMemory<float> *b, int ldb) {
2920 return DoBlasInternal(cublasStrmm, stream, true /* = pointer_mode_host */,
2921 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2922 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2923 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
2924 GpuMemoryMutable(b), ldb);
2925 }
2926
2927 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2928 blas::UpperLower uplo, blas::Transpose transa,
2929 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2930 const DeviceMemory<double> &a, int lda,
2931 DeviceMemory<double> *b, int ldb) {
2932 return DoBlasInternal(cublasDtrmm, stream, true /* = pointer_mode_host */,
2933 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2934 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2935 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
2936 GpuMemoryMutable(b), ldb);
2937 }
2938
2939 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2940 blas::UpperLower uplo, blas::Transpose transa,
2941 blas::Diagonal diag, uint64 m, uint64 n,
2942 std::complex<float> alpha,
2943 const DeviceMemory<std::complex<float>> &a, int lda,
2944 DeviceMemory<std::complex<float>> *b, int ldb) {
2945 auto cb_alpha = GpuComplexValue(alpha);
2946 return DoBlasInternal(cublasCtrmm, stream, true /* = pointer_mode_host */,
2947 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2948 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2949 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2950 GpuComplex(GpuMemoryMutable(b)), ldb,
2951 GpuComplex(GpuMemoryMutable(b)), ldb);
2952 }
2953
2954 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2955 blas::UpperLower uplo, blas::Transpose transa,
2956 blas::Diagonal diag, uint64 m, uint64 n,
2957 std::complex<double> alpha,
2958 const DeviceMemory<std::complex<double>> &a, int lda,
2959 DeviceMemory<std::complex<double>> *b, int ldb) {
2960 auto cb_alpha = GpuComplexValue(alpha);
2961 return DoBlasInternal(cublasZtrmm, stream, true /* = pointer_mode_host */,
2962 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2963 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2964 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2965 GpuComplex(GpuMemoryMutable(b)), ldb,
2966 GpuComplex(GpuMemoryMutable(b)), ldb);
2967 }
2968
2969 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
2970 blas::UpperLower uplo, blas::Transpose transa,
2971 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2972 const DeviceMemory<float> &a, int lda,
2973 DeviceMemory<float> *b, int ldb) {
2974 return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
2975 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2976 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2977 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
2978 }
2979
2980 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
2981 blas::UpperLower uplo, blas::Transpose transa,
2982 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2983 const DeviceMemory<double> &a, int lda,
2984 DeviceMemory<double> *b, int ldb) {
2985 return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
2986 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2987 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2988 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
2989 }
2990
2991 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
2992 blas::UpperLower uplo, blas::Transpose transa,
2993 blas::Diagonal diag, uint64 m, uint64 n,
2994 std::complex<float> alpha,
2995 const DeviceMemory<std::complex<float>> &a, int lda,
2996 DeviceMemory<std::complex<float>> *b, int ldb) {
2997 auto cb_alpha = GpuComplexValue(alpha);
2998 return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
2999 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3000 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3001 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3002 GpuComplex(GpuMemoryMutable(b)), ldb);
3003 }
3004
3005 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
3006 blas::UpperLower uplo, blas::Transpose transa,
3007 blas::Diagonal diag, uint64 m, uint64 n,
3008 std::complex<double> alpha,
3009 const DeviceMemory<std::complex<double>> &a, int lda,
3010 DeviceMemory<std::complex<double>> *b, int ldb) {
3011 auto cb_alpha = GpuComplexValue(alpha);
3012 return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
3013 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3014 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3015 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3016 GpuComplex(GpuMemoryMutable(b)), ldb);
3017 }
3018
3019 // We only use cublasLt from CUDA 11.0 onward.
3020 #if CUDA_VERSION >= 11000
3021
3022 namespace {
3023
3024 template <typename T>
3025 inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
3026 cublasLtMatrixLayoutAttribute_t attr,
3027 const T &value) {
3028 cublasStatus_t status =
3029 cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
3030 if (status != CUBLAS_STATUS_SUCCESS) {
3031 return port::Status(
3032 port::error::INTERNAL,
3033 absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
3034 ", value=", value, ") failed: ", ToString(status)));
3035 }
3036 return port::Status::OK();
3037 }
3038
3039 template <typename T>
3040 inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle,
3041 cublasLtMatmulAlgoConfigAttributes_t attr,
3042 const T &value) {
3043 cublasStatus_t status =
3044 cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
3045 if (status != CUBLAS_STATUS_SUCCESS) {
3046 return port::Status(
3047 port::error::INTERNAL,
3048 absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
3049 ", value=", value, ") failed: ", ToString(status)));
3050 }
3051 return port::Status::OK();
3052 }
3053
3054 template <typename T>
3055 inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
3056 cublasLtMatmulPreferenceAttributes_t attr,
3057 const T &value) {
3058 cublasStatus_t status =
3059 cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
3060 if (status != CUBLAS_STATUS_SUCCESS) {
3061 return port::Status(
3062 port::error::INTERNAL,
3063 absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
3064 ", value=", value, ") failed: ", ToString(status)));
3065 }
3066 return port::Status::OK();
3067 }
3068
3069 template <typename T>
3070 inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle,
3071 cublasLtMatmulAlgoConfigAttributes_t attr,
3072 T *value) {
3073 auto mutable_handle = const_cast<cublasLtMatmulAlgo_t *>(handle);
3074 size_t bytes_written = 0;
3075 return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
3076 sizeof(T), &bytes_written) ==
3077 CUBLAS_STATUS_SUCCESS &&
3078 bytes_written == sizeof(T);
3079 }
3080
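// absl::StrCat has no overload for arbitrary pointers, so pointer-valued
// attributes (e.g. the bias pointer) are rendered as hex in error messages,
// while all other attribute values are passed through unchanged.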
3081 template <typename T>
3082 inline const T &ValueForStrCat(const T &value) {
3083 return value;
3084 }
3085 template <typename T>
3086 inline absl::Hex ValueForStrCat(T *ptr) {
3087 return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
3088 }
3089
3090 template <typename T>
3091 inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
3092 cublasLtMatmulDescAttributes_t attr,
3093 const T &value) {
3094 cublasStatus_t status =
3095 cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
3096 if (status != CUBLAS_STATUS_SUCCESS) {
3097 return port::Status(
3098 port::error::INTERNAL,
3099 absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
3100 ValueForStrCat(value), ") failed: ", ToString(status)));
3101 }
3102 return port::Status::OK();
3103 }
3104
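// RAII deleters and unique_ptr aliases for the opaque cublasLt handle types,
// so that descriptors created below are released automatically, including on
// the early-return error paths.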
3105 struct MatmulDescDestroyer {
3106 void operator()(cublasLtMatmulDesc_t matmul_desc) const {
3107 cublasLtMatmulDescDestroy(matmul_desc);
3108 }
3109 };
3110 struct LayoutDestroyer {
3111 void operator()(cublasLtMatrixLayout_t layout) const {
3112 cublasLtMatrixLayoutDestroy(layout);
3113 }
3114 };
3115 struct MatmulPreferenceDestroyer {
3116 void operator()(cublasLtMatmulPreference_t matmul_pref) const {
3117 cublasLtMatmulPreferenceDestroy(matmul_pref);
3118 }
3119 };
3120 using UniqueOpDesc =
3121 std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
3122 MatmulDescDestroyer>;
3123 using UniqueLayoutDesc =
3124 std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
3125 LayoutDestroyer>;
3126 using UniqueMatmulPreference =
3127 std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
3128 MatmulPreferenceDestroyer>;
3129
3130 port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
3131 blas::ComputationType computation_type, blas::DataType scale_type,
3132 blas::PointerMode pointer_mode, blas::Epilogue epilogue,
3133 blas::Transpose transa, blas::Transpose transb) {
3134 cublasLtMatmulDesc_t desc;
3135 cublasComputeType_t cublas_compute_type =
3136 CUBLASComputationType(computation_type);
3137 cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
3138 cublasStatus_t status =
3139 cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
3140 if (status != CUBLAS_STATUS_SUCCESS) {
3141 return port::Status(
3142 port::error::INTERNAL,
3143 absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
3144 computation_type, ") failed: ", ToString(status)));
3145 }
3146 UniqueOpDesc unique_desc(desc);
3147 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
3148 CUBLASPointerMode(pointer_mode)));
3149 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
3150 CUBLASEpilogue(epilogue)));
3151 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
3152 CUDABlasTranspose(transa)));
3153 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
3154 CUDABlasTranspose(transb)));
3155 return unique_desc;
3156 }
3157
3158 port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
3159 blas::DataType data_type, uint64 rows, uint64 cols, int64_t ld,
3160 int64_t stride, int batch_count) {
3161 cublasLtMatrixLayout_t desc;
3162 cublasStatus_t status = cublasLtMatrixLayoutCreate(
3163 &desc, GetCUDADataType(data_type), rows, cols, ld);
3164 if (status != CUBLAS_STATUS_SUCCESS) {
3165 return port::Status(
3166 port::error::INTERNAL,
3167 absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
3168 }
3169 UniqueLayoutDesc unique_desc(desc);
3170 SE_RETURN_IF_ERROR(
3171 SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
3172 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3173 desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
3174 return unique_desc;
3175 }
3176
3177 // Helper function to allocate workspace.
3178 port::Status AllocateWorkspace(void **workspace,
3179 ScratchAllocator *scratch_allocator,
3180 size_t num_bytes) {
3181 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
3182 scratch_allocator->AllocateBytes(num_bytes));
3183 *workspace = (void *)GpuMemoryMutable(&workspace_bytes);
3184 return port::Status::OK();
3185 }
3186
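// Maps an element type to the blas::ComputationType of matching precision.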
3187 template <typename T>
3188 blas::ComputationType ToComputationType();
3189 template <>
3190 blas::ComputationType ToComputationType<Eigen::half>() {
3191 return blas::ComputationType::kF16;
3192 }
3193 template <>
3194 blas::ComputationType ToComputationType<float>() {
3195 return blas::ComputationType::kF32;
3196 }
3197 template <>
3198 blas::ComputationType ToComputationType<double>() {
3199 return blas::ComputationType::kF64;
3200 }
3201 template <>
3202 blas::ComputationType ToComputationType<std::complex<float>>() {
3203 return blas::ComputationType::kComplexF32;
3204 }
3205 template <>
3206 blas::ComputationType ToComputationType<std::complex<double>>() {
3207 return blas::ComputationType::kComplexF64;
3208 }
3209
3210 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
3211 public:
3212 port::Status init(const blas::BlasLtMatmulPlanParams &p) {
3213 params_ = p;
3214 scale_type_ = GetScaleType(p.c_type, p.computation_type);
3215     SE_ASSIGN_OR_RETURN(
3216         op_desc_,
3217         CreateCublasLtOperationDesc(p.computation_type, scale_type_,
3218                                     p.pointer_mode, p.epilogue, p.transa,
3219                                     p.transb));
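    // cublasLt layout descriptors describe the matrices as stored in memory,
    // i.e. before the transpose ops above are applied, so the logical GEMM
    // dimensions are swapped when transa/transb request a transpose.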
3220 uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
3221 uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
3222 uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
3223 uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
3224 SE_ASSIGN_OR_RETURN(
3225 a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
3226 p.stride_a, capped_batch_count()));
3227 SE_ASSIGN_OR_RETURN(
3228 b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
3229 p.stride_b, capped_batch_count()));
3230 SE_ASSIGN_OR_RETURN(
3231 c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3232 capped_batch_count()));
3233 SE_ASSIGN_OR_RETURN(
3234 d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3235 capped_batch_count()));
3236 remainder_batch_count_ =
3237 p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
3238 if (remainder_batch_count_) {
3239 SE_ASSIGN_OR_RETURN(
3240 a_remainder_desc_,
3241 CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
3242 remainder_batch_count_));
3243 SE_ASSIGN_OR_RETURN(
3244 b_remainder_desc_,
3245 CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
3246 remainder_batch_count_));
3247 SE_ASSIGN_OR_RETURN(
3248 c_remainder_desc_,
3249 CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3250 remainder_batch_count_));
3251 SE_ASSIGN_OR_RETURN(
3252 d_remainder_desc_,
3253 CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3254 remainder_batch_count_));
3255 }
3256 return port::Status::OK();
3257 }
3258
3259 cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
3260 cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
3261 cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
3262 cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
3263 cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
3264 cublasLtMatrixLayout_t a_remainder_desc() const {
3265 return a_remainder_desc_.get();
3266 }
3267 cublasLtMatrixLayout_t b_remainder_desc() const {
3268 return b_remainder_desc_.get();
3269 }
3270 cublasLtMatrixLayout_t c_remainder_desc() const {
3271 return c_remainder_desc_.get();
3272 }
3273 cublasLtMatrixLayout_t d_remainder_desc() const {
3274 return d_remainder_desc_.get();
3275 }
3276
3277 const blas::BlasLtMatmulPlanParams ¶ms() const { return params_; }
3278 blas::DataType scale_type() const { return scale_type_; }
3279 blas::DataType ab_type() const override { return params_.ab_type; }
3280 blas::DataType c_type() const override { return params_.c_type; }
3281 int capped_batch_count() const {
3282 return std::min(params_.batch_count, kMaxBatchCount);
3283 }
3284 int remainder_batch_count() const { return remainder_batch_count_; }
3285
3286   // Note: must be const to satisfy the interface even though it mutates the
3287   // descriptor; it is always set immediately before the plan is executed, so
  // a stale value never leaks into a later execution.
3288 bool SetBiasPointer(const void *bias) const;
3289
3290 private:
3291 // In some cases cublasLt does not support large batch sizes, so we need to
3292 // split up such cases into multiple calls.
3293 static constexpr int kMaxBatchCount = 65535;
3294 blas::BlasLtMatmulPlanParams params_;
3295 blas::DataType scale_type_;
3296 UniqueOpDesc op_desc_;
3297 // These have batch count set to capped_batch_count().
3298 UniqueLayoutDesc a_desc_;
3299 UniqueLayoutDesc b_desc_;
3300 UniqueLayoutDesc c_desc_;
3301 UniqueLayoutDesc d_desc_;
3302 int remainder_batch_count_;
3303 // These have batch count set to remainder_batch_count_, and are only created
3304   // if params_.batch_count > kMaxBatchCount.
3305 UniqueLayoutDesc a_remainder_desc_;
3306 UniqueLayoutDesc b_remainder_desc_;
3307 UniqueLayoutDesc c_remainder_desc_;
3308 UniqueLayoutDesc d_remainder_desc_;
3309 };
3310
3311 /*static*/ constexpr int CUDABlasLtMatmulPlan::kMaxBatchCount;
3312
3313 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
3314 return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
3315 bias)
3316 .ok();
3317 }
3318
3319 class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
3320 public:
3321 CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
3322 cublasLtMatmulAlgo_t algo, size_t workspace_size)
3323 : index_(index), algo_(algo), workspace_size_(workspace_size) {}
3324
3325 blas::AlgorithmType index() const override { return index_; }
3326
3327 size_t workspace_size() const override { return workspace_size_; }
3328
3329 const cublasLtMatmulAlgo_t *algo() const { return &algo_; }
3330
3331 int algo_id() const {
3332 int id;
3333 GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
3334 return id;
3335 }
3336
3337 private:
3338 blas::AlgorithmType index_;
3339 cublasLtMatmulAlgo_t algo_;
3340 size_t workspace_size_;
3341 };
3342
3343 port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
3344 const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) {
3345 cublasLtMatmulPreference_t preference;
3346 cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
3347 if (status != CUBLAS_STATUS_SUCCESS) {
3348 return port::Status(port::error::INTERNAL,
3349 absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
3350 ToString(status)));
3351 }
3352 UniqueMatmulPreference unique_preference(preference);
3353 SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
3354 CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
3355 max_workspace_bytes));
3356
3357 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3358 if (cuda_plan.params().batch_count == 0) {
3359 return unique_preference;
3360 }
3361 // This is a workaround for a known issue in cuBlasLt where the heuristic may
3362 // in rare cases select an algo that does not support the specified stride.
3363 // Specifying the alignment requirements manually like this avoids the issue.
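  // (stride & -stride) isolates the lowest set bit of the stride, i.e. its
  // largest power-of-two factor; multiplying by the element size yields a byte
  // alignment that every batch member is guaranteed to satisfy.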
3364 auto get_alignment_bytes = [](int64_t stride, blas::DataType dtype) {
3365 return (stride & -stride) * GetDataTypeSizeBytes(dtype);
3366 };
3367 if (cuda_plan.params().stride_a) {
3368 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3369 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
3370 (uint32)get_alignment_bytes(cuda_plan.params().stride_a,
3371 cuda_plan.params().ab_type)));
3372 }
3373 if (cuda_plan.params().stride_b) {
3374 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3375 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
3376 (uint32)get_alignment_bytes(cuda_plan.params().stride_b,
3377 cuda_plan.params().ab_type)));
3378 }
3379 if (cuda_plan.params().stride_c) {
3380 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3381 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
3382 (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
3383 cuda_plan.params().c_type)));
3384 }
3385 if (cuda_plan.params().stride_c) {
3386 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3387 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
3388 (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
3389 cuda_plan.params().c_type)));
3390 }
3391 return unique_preference;
3392 }
3393
3394 } // namespace
3395
3396 #endif // CUDA_VERSION >= 11000
3397
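// Illustrative caller-side sketch (non-normative; `blas`, `stream`, the device
// buffers and `params` are assumed to be set up by the surrounding
// StreamExecutor code):
//
//   auto plan = blas->CreateBlasLtMatmulPlan(params).ConsumeValueOrDie();
//   auto algos = blas->GetBlasLtMatmulAlgorithms(
//                        plan.get(), /*max_workspace_size=*/1 << 22,
//                        /*max_algorithm_count=*/1)
//                    .ConsumeValueOrDie();
//   blas->DoBlasLtMatmul(stream, plan.get(), alpha, a, b, beta, c,
//                        scratch_allocator, algos[0].get(),
//                        /*bias=*/DeviceMemoryBase(),
//                        /*output_profile_result=*/nullptr);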
3398 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
3399 CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
3400 #if CUDA_VERSION >= 11000
3401 auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
3402 SE_RETURN_IF_ERROR(cuda_plan->init(p));
3403 return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
3404 std::move(cuda_plan));
3405 #else
3406 return port::Status(
3407 port::error::UNIMPLEMENTED,
3408 "CreateBlasLtMatmulPlan is not supported with this version of CUDA");
3409 #endif
3410 }
3411
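// Queries cublasLt's heuristic for up to `max_algorithm_count` candidate
// algorithms that fit within `max_workspace_size`, wrapping each result in a
// CUDABlasLtMatmulAlgorithm. When `for_remainder_batch` is set, the plan's
// remainder layout descriptors are used instead of the capped-batch ones.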
3412 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
3413 CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
3414 size_t max_workspace_size,
3415 int max_algorithm_count,
3416 bool for_remainder_batch) {
3417 #if CUDA_VERSION >= 11000
3418 SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
3419 CreateCublasLtMatmulPreference(plan, max_workspace_size));
3420
3421 std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
3422 {
3423 absl::MutexLock lock(&mu_);
3424
3425 CHECK(blasLt_ != nullptr);
3426
3427 gpu::ScopedActivateExecutorContext sac{parent_};
3428
3429 int found_algorithm_count = 0;
3430 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3431 const auto &a_desc =
3432 for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
3433 const auto &b_desc =
3434 for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
3435 const auto &c_desc =
3436 for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
3437 const auto &d_desc =
3438 for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
3439 cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
3440 blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
3441 preference.get(), max_algorithm_count, results.data(),
3442 &found_algorithm_count);
3443 if (status != CUBLAS_STATUS_SUCCESS) {
3444 return port::Status(
3445 port::error::INTERNAL,
3446 absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
3447 ToString(status)));
3448 }
3449 results.resize(found_algorithm_count);
3450 }
3451
3452 std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
3453 out_algorithms.reserve(results.size());
3454 for (size_t i = 0; i < results.size(); ++i) {
3455 const auto &result = results[i];
3456 if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
3457 out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
3458 i, result.algo, result.workspaceSize));
3459 }
3460 return out_algorithms;
3461 #else // if CUDA_VERSION < 11000
3462 return port::Status(
3463 port::error::UNIMPLEMENTED,
3464 "GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
3465 #endif
3466 }
3467
3468 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
3469 CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
3470 size_t max_workspace_size,
3471 int max_algorithm_count) {
3472 return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
3473 max_algorithm_count);
3474 }
3475
3476 #if CUDA_VERSION >= 11000
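// Validates that the scalar types, pointer mode and bias usage match the plan,
// allocates any required workspace from `scratch_allocator`, and then issues
// one cublasLtMatmul call per capped batch chunk, followed by an optional
// remainder call when batch_count exceeds kMaxBatchCount.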
3477 bool CUDABlas::DoBlasLtMatmulInternal(
3478 Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
3479 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
3480 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
3481 DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator,
3482 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) {
3483 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3484 const auto &cuda_algo =
3485 *static_cast<const CUDABlasLtMatmulAlgorithm *>(algorithm);
3486
3487 if (alpha.data_type() != cuda_plan.scale_type() ||
3488 beta.data_type() != cuda_plan.scale_type()) {
3489 VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do "
3490 "not match plan: expected "
3491             << cuda_plan.scale_type() << ", got alpha=" << alpha.data_type()
3492 << " beta=" << beta.data_type();
3493 return false;
3494 }
3495 if (alpha.is_pointer() != beta.is_pointer()) {
3496 VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
3497 "and `beta` is a pointer, but the other is not.";
3498 return false;
3499 }
3500 bool is_pointer_mode_host = !alpha.is_pointer();
3501 if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
3502 is_pointer_mode_host) {
3503 VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
3504 "pointer_mode for the given alpha/beta.";
3505 return false;
3506 }
3507 if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
3508 cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
3509 (bias != nullptr)) {
3510 VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
3511 "epilogue for the given bias pointer.";
3512 return false;
3513 }
3514 const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
3515 : alpha.opaque_value();
3516 const void *beta_ptr =
3517 beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value();
3518
3519 void *workspace = nullptr;
3520 if (cuda_algo.workspace_size()) {
3521 port::Status allocation_status = AllocateWorkspace(
3522 &workspace, scratch_allocator, cuda_algo.workspace_size());
3523 if (!allocation_status.ok()) {
3524 if (err_on_failure || VLOG_IS_ON(3)) {
3525 LOG(ERROR)
3526 << "Failed to allocate workspace for cublasLtMatmul algo with id: "
3527 << cuda_algo.algo_id() << " requiring "
3528 << cuda_algo.workspace_size() << " bytes of workspace";
3529 }
3530 return false;
3531 }
3532 }
3533
3534 // This is only used when batch_count > kMaxBatchCount.
3535 std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
3536 if (cuda_plan.remainder_batch_count()) {
3537 // There is no easy way to get the user-specified max workspace size here,
3538 // so we just allow a very small amount and don't worry too much about
3539 // performance because this is only used in rare cases. The same reasoning
3540 // applies to selection of the algorithm.
3541 size_t max_workspace_size = 4 * 1024 * 1024; // 4 MiB
3542 auto status_or_algorithms =
3543 GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
3544 /* max_algorithm_count = */ 1,
3545 /* for_remainder_batch = */ true);
3546 if (!status_or_algorithms.ok()) {
3547 if (err_on_failure || VLOG_IS_ON(3)) {
3548 LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
3549 }
3550 return false;
3551 }
3552 auto algorithms = status_or_algorithms.ConsumeValueOrDie();
3553 unique_remainder_algo = std::move(algorithms.front());
3554 }
3555
3556 cudaStream_t cuda_stream = CUDAStream(stream);
3557
3558 absl::MutexLock lock(&mu_);
3559
3560 if (bias != nullptr) {
3561 if (!cuda_plan.SetBiasPointer(bias.opaque())) {
3562 VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
3563 "pointer failed.";
3564 return false;
3565 }
3566 }
3567
3568 CHECK(blasLt_ != nullptr);
3569
3570 gpu::ScopedActivateExecutorContext sac{parent_};
3571
3572   // Plan execution is broken down into repeated calls with capped_batch_count,
3573 // followed by a final call with remainder_batch_count.
3574 // Cases where batch_count <= kMaxBatchCount require only a single call (a
3575 // single loop iteration and no remainder).
3576 int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
3577 int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
3578 const char *a_ptr = static_cast<const char *>(a.opaque());
3579 const char *b_ptr = static_cast<const char *>(b.opaque());
3580 const char *c_ptr = static_cast<const char *>(c.opaque());
3581 char *d_ptr = static_cast<char *>(d.opaque());
3582 int capped_batch_count = cuda_plan.capped_batch_count();
3583 for (int batch = 0;
3584 batch + capped_batch_count <= cuda_plan.params().batch_count;
3585 batch += capped_batch_count) {
3586 cublasStatus_t ret = cublasLtMatmul(
3587 blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
3588 b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
3589 cuda_plan.d_desc(), cuda_algo.algo(), workspace,
3590 cuda_algo.workspace_size(), cuda_stream);
3591 if (ret != CUBLAS_STATUS_SUCCESS) {
3592 if (err_on_failure || VLOG_IS_ON(3)) {
3593 LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
3594 }
3595 return false;
3596 }
3597 a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
3598 b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
3599 c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
3600 d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
3601 }
3602 // This is only used when batch_count > kMaxBatchCount.
3603 if (cuda_plan.remainder_batch_count()) {
3604 const auto &remainder_algo =
3605 *static_cast<const CUDABlasLtMatmulAlgorithm *>(
3606 unique_remainder_algo.get());
3607 if (remainder_algo.workspace_size()) {
3608 port::Status allocation_status = AllocateWorkspace(
3609 &workspace, scratch_allocator, remainder_algo.workspace_size());
3610 if (!allocation_status.ok()) {
3611 if (err_on_failure || VLOG_IS_ON(3)) {
3612 LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
3613 "with id: "
3614 << remainder_algo.algo_id() << " requiring "
3615 << remainder_algo.workspace_size()
3616 << " bytes of workspace";
3617 }
3618 return false;
3619 }
3620 }
3621 cublasStatus_t ret = cublasLtMatmul(
3622 blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
3623 cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
3624 beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
3625 cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
3626 remainder_algo.workspace_size(), cuda_stream);
3627 if (ret != CUBLAS_STATUS_SUCCESS) {
3628 if (err_on_failure || VLOG_IS_ON(3)) {
3629 LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
3630 << ToString(ret);
3631 }
3632 return false;
3633 }
3634 }
3635 return true;
3636 }
3637 #endif // CUDA_VERSION >= 11000
3638
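// Public entry point: when the plan's scale type is F32 but c is F16, the
// half-precision alpha/beta are cast to float; the call is optionally timed
// with a GpuTimer when profiling is requested, and the result is written in
// place (d aliases c).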
3639 bool CUDABlas::DoBlasLtMatmul(
3640 Stream *stream, const blas::IBlasLtMatmulPlan *plan,
3641 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
3642 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
3643 DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
3644 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
3645 blas::ProfileResult *output_profile_result) {
3646 #if CUDA_VERSION >= 11000
3647 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3648 HostOrDeviceScalar<void> alpha_cast = alpha;
3649 HostOrDeviceScalar<void> beta_cast = beta;
3650 if (cuda_plan.c_type() == blas::DataType::kHalf &&
3651 cuda_plan.scale_type() == blas::DataType::kFloat) {
3652     // The given alpha and beta types are F16 (they always match c), but the
3653     // F32-based computation types require F32 scale factors, so cast them here.
3654 if (alpha.is_pointer() || beta.is_pointer()) {
3655 // We cannot easily convert a pointer to f16 memory to a pointer to f32
3656 // memory from here, so we don't support this for now.
3657 return false;
3658 }
3659 alpha_cast = HostOrDeviceScalar<void>(
3660 static_cast<float>(alpha.value<Eigen::half>()));
3661 beta_cast =
3662 HostOrDeviceScalar<void>(static_cast<float>(beta.value<Eigen::half>()));
3663 }
3664
3665 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3666 if (output_profile_result) {
3667 timer.reset(new GpuTimer(parent_));
3668 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3669 return false;
3670 }
3671 }
3672
3673 bool err_on_failure = timer != nullptr;
3674 bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast,
3675 a, b, beta_cast, c, c, scratch_allocator,
3676 algorithm, bias);
3677
3678 if (timer && result) {
3679 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
3680 // state.
3681 if (!timer->Stop(AsGpuStream(stream))) {
3682 return false;
3683 }
3684 output_profile_result->set_is_valid(true);
3685 output_profile_result->set_algorithm(algorithm->index());
3686 output_profile_result->set_elapsed_time_in_ms(
3687 timer->GetElapsedMilliseconds());
3688 }
3689 return result;
3690 #else // if CUDA_VERSION < 11000
3691 return false;
3692 #endif
3693 }
3694
3695 port::Status CUDABlas::GetVersion(std::string *version) {
3696 absl::MutexLock lock(&mu_);
3697
3698 int v;
3699 auto status = cublasGetVersion(blas_, &v);
3700 if (status != CUBLAS_STATUS_SUCCESS) {
3701 return port::InternalError(ToString(status));
3702 }
3703 *version = std::to_string(v);
3704 return port::Status::OK();
3705 }
3706
3707 } // namespace gpu
3708
3709 void initialize_cublas() {
3710 port::Status status =
3711 PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
3712 cuda::kCudaPlatformId, gpu::kCuBlasPlugin, "cuBLAS",
3713 [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * {
3714 gpu::GpuExecutor *cuda_executor =
3715 dynamic_cast<gpu::GpuExecutor *>(parent);
3716 if (cuda_executor == nullptr) {
3717 LOG(ERROR)
3718 << "Attempting to initialize an instance of the cuBLAS "
3719 << "support library with a non-CUDA StreamExecutor";
3720 return nullptr;
3721 }
3722
3723 gpu::CUDABlas *blas = new gpu::CUDABlas(cuda_executor);
3724 if (!blas->Init()) {
3725 // Note: Init() will log a more specific error.
3726 delete blas;
3727 return nullptr;
3728 }
3729 return blas;
3730 });
3731
3732 if (!status.ok()) {
3733 LOG(ERROR) << "Unable to register cuBLAS factory: "
3734 << status.error_message();
3735 }
3736
3737 PluginRegistry::Instance()->SetDefaultFactory(
3738 cuda::kCudaPlatformId, PluginKind::kBlas, gpu::kCuBlasPlugin);
3739 }
3740
3741 } // namespace stream_executor
3742
3743 REGISTER_MODULE_INITIALIZER(register_cublas,
3744 { stream_executor::initialize_cublas(); });
3745