1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "third_party/gpus/cuda/include/cublasLt.h"
17 #include "third_party/gpus/cuda/include/cublas_v2.h"
18 #include "third_party/gpus/cuda/include/cuda.h"
19
20 #define SE_CUDA_DATA_HALF CUDA_R_16F
21
22 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
23
24 // Both Eigen Half.h and CUDA cuda_fp16.h provide similar typedef for __half. As
25 // such, there are two ways to get the typedef for __half:
26 //
27 // (1) Includes cuda_fp16.h and defines EIGEN_HAS_CUDA_FP16.
28 // (2) Neither includes cuda_fp16.h nor defines EIGEN_HAS_CUDA_FP16.
29 //
30 // Due to issue b/73793421, when the first approach is used and NVCC is used to
31 // compile this file, NVCC will complain duplicated definition for
32 // EIGEN_HAS_CUDA_FP16. On the other hand, when the second approach is used and
33 // clang is used to compile this file, clang will not understand __half
34 // due to missing the definition and macro EIGEN_HAS_CUDA_FP16.
35 //
36 // Because this file may be compiled with CLANG but will never be compiled with
37 // NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we have
38 // to use the second approach because the data member in the __half defined
39 // by CUDA > 9.0 is `__x` while Eigen expects it to be `x`.
40 //
41 // TODO(b/73793421): Remove the following code block to switch to the second
42 // approach when the issue is fixed.
43 #if CUDA_VERSION < 9000
44 #include "third_party/gpus/cuda/include/cuda_fp16.h"
45 #define EIGEN_HAS_CUDA_FP16
46 #endif
47
48 #include <complex>
49
50 #include "absl/strings/str_cat.h"
51 #include "absl/strings/str_format.h"
52 #include "third_party/eigen3/Eigen/Core"
53 #include "tensorflow/core/platform/tensor_float_32_utils.h"
54 #include "tensorflow/core/util/env_var.h"
55 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
56 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
57 #include "tensorflow/stream_executor/cuda/cuda_helpers.h"
58 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
59 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
60 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
61 #include "tensorflow/stream_executor/device_memory.h"
62 #include "tensorflow/stream_executor/lib/env.h"
63 #include "tensorflow/stream_executor/lib/initialize.h"
64 #include "tensorflow/stream_executor/lib/status.h"
65 #include "tensorflow/stream_executor/lib/status_macros.h"
66 #include "tensorflow/stream_executor/platform/logging.h"
67 #include "tensorflow/stream_executor/platform/port.h"
68 #include "tensorflow/stream_executor/plugin_registry.h"
69 #include "tensorflow/stream_executor/scratch_allocator.h"
70 #include "tensorflow/stream_executor/stream_executor.h"
71
72 namespace stream_executor {
73 namespace gpu {
74
75 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
76
ToString(cublasStatus_t status)77 static std::string ToString(cublasStatus_t status) {
78 switch (status) {
79 case CUBLAS_STATUS_SUCCESS:
80 return "CUBLAS_STATUS_SUCCESS";
81 case CUBLAS_STATUS_NOT_INITIALIZED:
82 return "CUBLAS_STATUS_NOT_INITIALIZED";
83 case CUBLAS_STATUS_ALLOC_FAILED:
84 return "CUBLAS_STATUS_ALLOC_FAILED";
85 case CUBLAS_STATUS_INVALID_VALUE:
86 return "CUBLAS_STATUS_INVALID_VALUE";
87 case CUBLAS_STATUS_ARCH_MISMATCH:
88 return "CUBLAS_STATUS_ARCH_MISMATCH";
89 case CUBLAS_STATUS_MAPPING_ERROR:
90 return "CUBLAS_STATUS_MAPPING_ERROR";
91 case CUBLAS_STATUS_EXECUTION_FAILED:
92 return "CUBLAS_STATUS_EXECUTION_FAILED";
93 case CUBLAS_STATUS_INTERNAL_ERROR:
94 return "CUBLAS_STATUS_INTERNAL_ERROR";
95 #if CUDA_VERSION >= 8000
96 case CUBLAS_STATUS_NOT_SUPPORTED:
97 return "CUBLAS_STATUS_NOT_SUPPORTED";
98 case CUBLAS_STATUS_LICENSE_ERROR:
99 return "CUBLAS_STATUS_LICENSE_ERROR";
100 #endif
101 default:
102 return absl::StrCat("<invalid cublas status: ", status, ">");
103 }
104 }
105
106 // cuBLAS has interfaces that permit pointers to be passed from either the host
107 // memory space or the device memory space; however, you must instruct it as to
108 // which address space those pointers are in with cublasSetPointerMode.
109 //
110 // This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
111 // you are about to perform in a given scope.
112 //
113 // The prior cuBLAS pointer mode is retained and restored when this object goes
114 // out of scope.
115 class ScopedCublasPointerMode {
116 public:
117 // Note that, because the setting of the cublas pointer mode is fallible,
118 // construction of this scoped datatype must be paired with a call to
119 // Init().
120 //
121 // Parameters:
122 // handle: The cublas library handle to act upon in setting the pointer mode.
ScopedCublasPointerMode(cublasHandle_t handle)123 explicit ScopedCublasPointerMode(cublasHandle_t handle)
124 : handle_(handle), ok_(false) {}
125
126 // Attempts the switch to the requested scoped pointer mode, new_mode.
127 //
128 // Note that when false is returned, an appropriate error has already been
129 // logged.
Init(cublasPointerMode_t new_mode)130 bool Init(cublasPointerMode_t new_mode) {
131 cublasStatus_t ret = cublasGetPointerMode(handle_, &old_mode_);
132 if (ret != CUBLAS_STATUS_SUCCESS) {
133 LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
134 return ok_ = false;
135 }
136
137 ret = cublasSetPointerMode(handle_, new_mode);
138 if (ret != CUBLAS_STATUS_SUCCESS) {
139 LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
140 return ok_ = false;
141 }
142
143 return ok_ = true;
144 }
145
146 // Switches back to the prior pointer mode, if the switch operation was
147 // successful in the first place.
~ScopedCublasPointerMode()148 ~ScopedCublasPointerMode() {
149 if (ok_) {
150 cublasStatus_t ret = cublasSetPointerMode(handle_, old_mode_);
151 if (ret != CUBLAS_STATUS_SUCCESS) {
152 LOG(ERROR) << "failed to set former cublas pointer mode: "
153 << ToString(ret);
154 }
155 }
156 }
157
158 private:
159 cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
160 cublasPointerMode_t old_mode_; // Prior cuBLAS pointer mode, to be restored.
161 bool ok_; // Whether the change was successful.
162 };
163
164 #if CUDA_VERSION >= 9000
165 // cuBLAS has interfaces that permit computations to use the Volta hardware.
166 // This must be enabled via the cublasGet/SetMathMode APIs.
167 //
168 // This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
169 // you are about to perform in a given scope.
170 //
171 // The prior cuBLAS math mode is retained and restored when this object goes
172 // out of scope.
173 class ScopedCublasMathMode {
174 public:
175 // Note that, because the setting of the cublas math mode is fallible,
176 // construction of this scoped datatype must be paired with a call to
177 // Init().
178 //
179 // Parameters:
180 // handle: The cublas library handle to act upon in setting the math mode.
ScopedCublasMathMode(cublasHandle_t handle)181 explicit ScopedCublasMathMode(cublasHandle_t handle)
182 : handle_(handle), ok_(false) {}
183
184 // Attempts the switch to the requested scoped math mode, new_mode.
185 //
186 // Note that when false is returned, an appropriate error has already been
187 // logged.
Init(cublasMath_t new_mode)188 bool Init(cublasMath_t new_mode) {
189 cublasStatus_t ret = cublasGetMathMode(handle_, &old_mode_);
190 if (ret != CUBLAS_STATUS_SUCCESS) {
191 LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
192 return ok_ = false;
193 }
194
195 ret = cublasSetMathMode(handle_, new_mode);
196 if (ret != CUBLAS_STATUS_SUCCESS) {
197 LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
198 return ok_ = false;
199 }
200 return ok_ = true;
201 }
202
203 // Switches back to the prior math mode, if the switch operation was
204 // successful in the first place.
~ScopedCublasMathMode()205 ~ScopedCublasMathMode() {
206 if (ok_) {
207 cublasStatus_t ret = cublasSetMathMode(handle_, old_mode_);
208 if (ret != CUBLAS_STATUS_SUCCESS) {
209 LOG(ERROR) << "failed to set former cublas math mode: "
210 << ToString(ret);
211 }
212 }
213 }
214
215 private:
216 cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
217 cublasMath_t old_mode_; // Prior cuBLAS math mode, to be restored.
218 bool ok_; // Whether the change was successful.
219 };
220 #endif // CUDA_VERSION >= 9000
221
Init()222 bool CUDABlas::Init() {
223 gpu::ScopedActivateExecutorContext sac{parent_};
224 cublasStatus_t ret = cublasCreate(&blas_);
225 if (ret != CUBLAS_STATUS_SUCCESS) {
226 LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
227 return false;
228 }
229
230 #if CUDA_VERSION >= 11000
231 ret = cublasLtCreate(&blasLt_);
232 if (ret != CUBLAS_STATUS_SUCCESS) {
233 LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
234 return false;
235 }
236 #endif // CUDA_VERSION >= 11000
237
238 return true;
239 }
240
CUDABlas(gpu::GpuExecutor * parent)241 CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
242 : parent_(CHECK_NOTNULL(parent)),
243 blas_(nullptr)
244 #if CUDA_VERSION >= 11000
245 ,
246 blasLt_(nullptr)
247 #endif
248 {
249 }
250
~CUDABlas()251 CUDABlas::~CUDABlas() {
252 if (blas_ != nullptr) {
253 gpu::ScopedActivateExecutorContext sac{parent_};
254 cublasDestroy(blas_);
255 }
256 #if CUDA_VERSION >= 11000
257 if (blasLt_ != nullptr) {
258 gpu::ScopedActivateExecutorContext sac{parent_};
259 cublasLtDestroy(blasLt_);
260 }
261 #endif
262 }
263
SetStream(Stream * stream)264 bool CUDABlas::SetStream(Stream *stream) {
265 CHECK(stream != nullptr);
266 CHECK(AsGpuStreamValue(stream) != nullptr);
267 CHECK(blas_ != nullptr);
268 gpu::ScopedActivateExecutorContext sac{parent_};
269 cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
270 if (ret != CUBLAS_STATUS_SUCCESS) {
271 LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
272 return false;
273 }
274
275 return true;
276 }
277
CUDAStream(Stream * stream)278 cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
279 CHECK(stream != nullptr);
280 CHECK(AsGpuStreamValue(stream) != nullptr);
281 gpu::ScopedActivateExecutorContext sac{parent_};
282 return AsGpuStreamValue(stream);
283 }
284
285 namespace {
286
287 // Helper functions transforming blas arguments into cuBLAS arguments.
288
CUDABlasTranspose(blas::Transpose trans)289 cublasOperation_t CUDABlasTranspose(blas::Transpose trans) {
290 switch (trans) {
291 case blas::Transpose::kNoTranspose:
292 return CUBLAS_OP_N;
293 case blas::Transpose::kTranspose:
294 return CUBLAS_OP_T;
295 case blas::Transpose::kConjugateTranspose:
296 return CUBLAS_OP_C;
297 default:
298 LOG(FATAL) << "Invalid value of blas::Transpose.";
299 }
300 }
301
CUDABlasUpperLower(blas::UpperLower uplo)302 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
303 switch (uplo) {
304 case blas::UpperLower::kUpper:
305 return CUBLAS_FILL_MODE_UPPER;
306 case blas::UpperLower::kLower:
307 return CUBLAS_FILL_MODE_LOWER;
308 default:
309 LOG(FATAL) << "Invalid value of blas::UpperLower.";
310 }
311 }
312
CUDABlasDiagonal(blas::Diagonal diag)313 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
314 switch (diag) {
315 case blas::Diagonal::kUnit:
316 return CUBLAS_DIAG_UNIT;
317 case blas::Diagonal::kNonUnit:
318 return CUBLAS_DIAG_NON_UNIT;
319 default:
320 LOG(FATAL) << "Invalid value of blas::Diagonal.";
321 }
322 }
323
CUDABlasSide(blas::Side side)324 cublasSideMode_t CUDABlasSide(blas::Side side) {
325 switch (side) {
326 case blas::Side::kLeft:
327 return CUBLAS_SIDE_LEFT;
328 case blas::Side::kRight:
329 return CUBLAS_SIDE_RIGHT;
330 default:
331 LOG(FATAL) << "Invalid value of blas::Side.";
332 }
333 }
334
335 // CUDADataType<T>::type translates from a C++ type (e.g. float) to a
336 // cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a
337 // blas::ComputationType to a cudaDataType_t.
338 //
339 // These are used to build the argument type and computation type args to
340 // cublasGemmEx.
341 template <typename T>
342 struct CUDADataType;
343
344 template <>
345 struct CUDADataType<Eigen::half> {
346 static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
347 };
348
349 template <>
350 struct CUDADataType<std::complex<Eigen::half>> {
351 static constexpr cudaDataType_t type = CUDA_C_16F;
352 };
353
354 template <>
355 struct CUDADataType<float> {
356 static constexpr cudaDataType_t type = CUDA_R_32F;
357 };
358
359 template <>
360 struct CUDADataType<std::complex<float>> {
361 static constexpr cudaDataType_t type = CUDA_C_32F;
362 };
363
364 template <>
365 struct CUDADataType<double> {
366 static constexpr cudaDataType_t type = CUDA_R_64F;
367 };
368
369 template <>
370 struct CUDADataType<std::complex<double>> {
371 static constexpr cudaDataType_t type = CUDA_C_64F;
372 };
373
374 template <>
375 struct CUDADataType<int> {
376 static constexpr cudaDataType_t type = CUDA_R_32I;
377 };
378
379 template <>
380 struct CUDADataType<int8> {
381 static constexpr cudaDataType_t type = CUDA_R_8I;
382 };
383
384 template <>
385 struct CUDADataType<std::complex<int8>> {
386 static constexpr cudaDataType_t type = CUDA_C_8I;
387 };
388
389 template <>
390 struct CUDADataType<uint8> {
391 static constexpr cudaDataType_t type = CUDA_R_8U;
392 };
393
394 template <>
395 struct CUDADataType<std::complex<uint8>> {
396 static constexpr cudaDataType_t type = CUDA_C_8U;
397 };
398
CUDAComputationType(blas::ComputationType ty)399 cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
400 switch (ty) {
401 case blas::ComputationType::kF16:
402 return CUDA_R_16F;
403 case blas::ComputationType::kF32:
404 return CUDA_R_32F;
405 case blas::ComputationType::kF64:
406 return CUDA_R_64F;
407 case blas::ComputationType::kI32:
408 return CUDA_R_32I;
409 case blas::ComputationType::kComplexF32:
410 return CUDA_C_32F;
411 case blas::ComputationType::kComplexF64:
412 return CUDA_C_64F;
413 case blas::ComputationType::kTF32AsF32: // fall-through
414 case blas::ComputationType::kBF16AsF32:
415 // These cases are currently only supported in the blasLt routines, which
416 // use CUBLASComputationType() instead.
417 LOG(FATAL) << "Invalid value of blas::ComputationType.";
418 }
419 }
420
421 #if CUDA_VERSION >= 11000
CUBLASComputationType(blas::ComputationType ty)422 cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
423 switch (ty) {
424 case blas::ComputationType::kF16:
425 return CUBLAS_COMPUTE_16F;
426 case blas::ComputationType::kF32: // fall-through
427 case blas::ComputationType::kComplexF32:
428 return CUBLAS_COMPUTE_32F;
429 case blas::ComputationType::kF64: // fall-through
430 case blas::ComputationType::kComplexF64:
431 return CUBLAS_COMPUTE_64F;
432 case blas::ComputationType::kI32:
433 return CUBLAS_COMPUTE_32I;
434 case blas::ComputationType::kTF32AsF32:
435 return CUBLAS_COMPUTE_32F_FAST_TF32;
436 case blas::ComputationType::kBF16AsF32:
437 return CUBLAS_COMPUTE_32F_FAST_16BF;
438 }
439 }
440 #endif // CUDA_VERSION >= 11000
441
GetScaleType(blas::DataType data_type,blas::ComputationType compute_type)442 blas::DataType GetScaleType(blas::DataType data_type,
443 blas::ComputationType compute_type) {
444 bool is_complex = data_type == blas::DataType::kComplexFloat ||
445 data_type == blas::DataType::kComplexDouble;
446 switch (compute_type) {
447 case blas::ComputationType::kF16:
448 return blas::DataType::kHalf;
449 case blas::ComputationType::kF32: // fall-through
450 case blas::ComputationType::kComplexF32: // fall-through
451 case blas::ComputationType::kTF32AsF32: // fall-through
452 case blas::ComputationType::kBF16AsF32:
453 return is_complex ? blas::DataType::kComplexFloat
454 : blas::DataType::kFloat;
455 case blas::ComputationType::kF64: // fall-through
456 case blas::ComputationType::kComplexF64:
457 return is_complex ? blas::DataType::kComplexDouble
458 : blas::DataType::kDouble;
459 case blas::ComputationType::kI32:
460 return blas::DataType::kInt32;
461 }
462 }
463
464 #if CUDA_VERSION >= 11000
CUBLASPointerMode(blas::PointerMode pointer_mode)465 cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
466 switch (pointer_mode) {
467 case blas::PointerMode::kHost:
468 return CUBLASLT_POINTER_MODE_HOST;
469 case blas::PointerMode::kDevice:
470 return CUBLASLT_POINTER_MODE_DEVICE;
471 }
472 }
CUBLASEpilogue(blas::Epilogue epilogue)473 cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
474 switch (epilogue) {
475 case blas::Epilogue::kDefault:
476 return CUBLASLT_EPILOGUE_DEFAULT;
477 case blas::Epilogue::kReLU:
478 return CUBLASLT_EPILOGUE_RELU;
479 case blas::Epilogue::kBias:
480 return CUBLASLT_EPILOGUE_BIAS;
481 case blas::Epilogue::kBiasThenReLU:
482 return CUBLASLT_EPILOGUE_RELU_BIAS;
483 }
484 }
485 #endif // CUDA_VERSION >= 11000
486
GetCUDADataType(blas::DataType ty)487 cudaDataType_t GetCUDADataType(blas::DataType ty) {
488 switch (ty) {
489 case blas::DataType::kHalf:
490 return CUDA_R_16F;
491 case blas::DataType::kFloat:
492 return CUDA_R_32F;
493 case blas::DataType::kDouble:
494 return CUDA_R_64F;
495 case blas::DataType::kInt8:
496 return CUDA_R_8I;
497 case blas::DataType::kInt32:
498 return CUDA_R_32I;
499 case blas::DataType::kComplexFloat:
500 return CUDA_C_32F;
501 case blas::DataType::kComplexDouble:
502 return CUDA_C_64F;
503 default:
504 LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType";
505 }
506 }
507
GetDataTypeSizeBytes(blas::DataType ty)508 int GetDataTypeSizeBytes(blas::DataType ty) {
509 switch (ty) {
510 case blas::DataType::kHalf:
511 return 2;
512 case blas::DataType::kFloat:
513 return 4;
514 case blas::DataType::kDouble:
515 return 8;
516 case blas::DataType::kInt8:
517 return 1;
518 case blas::DataType::kInt32:
519 return 4;
520 case blas::DataType::kComplexFloat:
521 return 8;
522 case blas::DataType::kComplexDouble:
523 return 16;
524 default:
525 LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes";
526 }
527 }
528
529 } // namespace
530
531 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT cublas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,cublasMath_t math_type,Args...args)532 bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
533 bool pointer_mode_host, bool err_on_failure,
534 cublasMath_t math_type, Args... args) {
535 absl::MutexLock lock(&mu_);
536
537 CHECK(blas_ != nullptr);
538 if (!SetStream(stream)) {
539 return false;
540 }
541
542 #if CUDA_VERSION >= 9000
543 ScopedCublasMathMode math_mode{blas_};
544 #if CUBLAS_VER_MAJOR >= 11
545 if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
546 tensorflow::tensor_float_32_execution_enabled()) {
547 #else
548 if (math_type == CUBLAS_TENSOR_OP_MATH) {
549 #endif
550 if (!math_mode.Init(math_type)) {
551 return false;
552 }
553 }
554 #endif
555
556 gpu::ScopedActivateExecutorContext sac{parent_};
557 ScopedCublasPointerMode pointer_mode{blas_};
558 if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
559 : CUBLAS_POINTER_MODE_DEVICE)) {
560 return false;
561 }
562 cublasStatus_t ret = cublas_func(blas_, args...);
563 if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
564 LOG(ERROR) << "failed to run cuBLAS routine: " << ToString(ret);
565 }
566 return ret == CUBLAS_STATUS_SUCCESS;
567 }
568
569 // cublas_func may be overloaded, so we need to figure out which one we really
570 // need to call based on the args. One way to do it is to wrap it in lambda.
571 #define AS_LAMBDA(func) \
572 [](auto &&... args) -> decltype( \
573 func(std::forward<decltype(args)>(args)...)) { \
574 return func(std::forward<decltype(args)>(args)...); \
575 }
576
577 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
578 const DeviceMemory<float> &x, int incx,
579 DeviceMemory<float> *result) {
580 return DoBlasInternal(cublasSasum, stream, false /* = pointer_mode_host */,
581 elem_count, GpuMemory(x), incx,
582 GpuMemoryMutable(result));
583 }
584
585 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
586 const DeviceMemory<double> &x, int incx,
587 DeviceMemory<double> *result) {
588 return DoBlasInternal(cublasDasum, stream, false /* = pointer_mode_host */,
589 elem_count, GpuMemory(x), incx,
590 GpuMemoryMutable(result));
591 }
592
593 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
594 const DeviceMemory<std::complex<float>> &x, int incx,
595 DeviceMemory<float> *result) {
596 return DoBlasInternal(cublasScasum, stream, false /* = pointer_mode_host */,
597 elem_count, GpuComplex(GpuMemory(x)), incx,
598 GpuMemoryMutable(result));
599 }
600
601 bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
602 const DeviceMemory<std::complex<double>> &x, int incx,
603 DeviceMemory<double> *result) {
604 return DoBlasInternal(cublasDzasum, stream, false /* = pointer_mode_host */,
605 elem_count, GpuComplex(GpuMemory(x)), incx,
606 GpuMemoryMutable(result));
607 }
608
609 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
610 const DeviceMemory<float> &x, int incx,
611 DeviceMemory<float> *y, int incy) {
612 return DoBlasInternal(cublasSaxpy, stream, true /* = pointer_mode_host */,
613 elem_count, &alpha, GpuMemory(x), incx,
614 GpuMemoryMutable(y), incy);
615 }
616
617 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
618 const DeviceMemory<double> &x, int incx,
619 DeviceMemory<double> *y, int incy) {
620 return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */,
621 elem_count, &alpha, GpuMemory(x), incx,
622 GpuMemoryMutable(y), incy);
623 }
624
625 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
626 std::complex<float> alpha,
627 const DeviceMemory<std::complex<float>> &x, int incx,
628 DeviceMemory<std::complex<float>> *y, int incy) {
629 auto cb_alpha = GpuComplexValue(alpha);
630 return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
631 elem_count, GpuComplex(&cb_alpha),
632 GpuComplex(GpuMemory(x)), incx,
633 GpuComplex(GpuMemoryMutable(y)), incy);
634 }
635
636 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
637 std::complex<double> alpha,
638 const DeviceMemory<std::complex<double>> &x, int incx,
639 DeviceMemory<std::complex<double>> *y, int incy) {
640 auto cb_alpha = GpuComplexValue(alpha);
641 return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
642 elem_count, GpuComplex(&cb_alpha),
643 GpuComplex(GpuMemory(x)), incx,
644 GpuComplex(GpuMemoryMutable(y)), incy);
645 }
646
647 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
648 const DeviceMemory<float> &x, int incx,
649 DeviceMemory<float> *y, int incy) {
650 return DoBlasInternal(cublasScopy, stream, true /* = pointer_mode_host */,
651 elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
652 incy);
653 }
654
655 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
656 const DeviceMemory<double> &x, int incx,
657 DeviceMemory<double> *y, int incy) {
658 return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */,
659 elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
660 incy);
661 }
662
663 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
664 const DeviceMemory<std::complex<float>> &x, int incx,
665 DeviceMemory<std::complex<float>> *y, int incy) {
666 return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */,
667 elem_count, GpuComplex(GpuMemory(x)), incx,
668 GpuComplex(GpuMemoryMutable(y)), incy);
669 }
670
671 bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
672 const DeviceMemory<std::complex<double>> &x, int incx,
673 DeviceMemory<std::complex<double>> *y, int incy) {
674 return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */,
675 elem_count, GpuComplex(GpuMemory(x)), incx,
676 GpuComplex(GpuMemoryMutable(y)), incy);
677 }
678
679 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
680 const DeviceMemory<float> &x, int incx,
681 const DeviceMemory<float> &y, int incy,
682 DeviceMemory<float> *result) {
683 return DoBlasInternal(cublasSdot, stream, false /* = pointer_mode_host */,
684 elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
685 GpuMemoryMutable(result));
686 }
687
688 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
689 const DeviceMemory<double> &x, int incx,
690 const DeviceMemory<double> &y, int incy,
691 DeviceMemory<double> *result) {
692 return DoBlasInternal(cublasDdot, stream, false /* = pointer_mode_host */,
693 elem_count, GpuMemory(x), incx, GpuMemory(y), incy,
694 GpuMemoryMutable(result));
695 }
696
697 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
698 const DeviceMemory<std::complex<float>> &x, int incx,
699 const DeviceMemory<std::complex<float>> &y, int incy,
700 DeviceMemory<std::complex<float>> *result) {
701 return DoBlasInternal(cublasCdotc, stream, false /* = pointer_mode_host */,
702 elem_count, GpuComplex(GpuMemory(x)), incx,
703 GpuComplex(GpuMemory(y)), incy,
704 GpuComplex(GpuMemoryMutable(result)));
705 }
706
707 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
708 const DeviceMemory<std::complex<double>> &x, int incx,
709 const DeviceMemory<std::complex<double>> &y, int incy,
710 DeviceMemory<std::complex<double>> *result) {
711 return DoBlasInternal(cublasZdotc, stream, false /* = pointer_mode_host */,
712 elem_count, GpuComplex(GpuMemory(x)), incx,
713 GpuComplex(GpuMemory(y)), incy,
714 GpuComplex(GpuMemoryMutable(result)));
715 }
716
717 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
718 const DeviceMemory<std::complex<float>> &x, int incx,
719 const DeviceMemory<std::complex<float>> &y, int incy,
720 DeviceMemory<std::complex<float>> *result) {
721 return DoBlasInternal(cublasCdotu, stream, false /* = pointer_mode_host */,
722 elem_count, GpuComplex(GpuMemory(x)), incx,
723 GpuComplex(GpuMemory(y)), incy,
724 GpuComplex(GpuMemoryMutable(result)));
725 }
726
727 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
728 const DeviceMemory<std::complex<double>> &x, int incx,
729 const DeviceMemory<std::complex<double>> &y, int incy,
730 DeviceMemory<std::complex<double>> *result) {
731 return DoBlasInternal(cublasZdotu, stream, false /* = pointer_mode_host */,
732 elem_count, GpuComplex(GpuMemory(x)), incx,
733 GpuComplex(GpuMemory(y)), incy,
734 GpuComplex(GpuMemoryMutable(result)));
735 }
736
737 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
738 const DeviceMemory<float> &x, int incx,
739 DeviceMemory<float> *result) {
740 return DoBlasInternal(cublasSnrm2, stream, false /* = pointer_mode_host */,
741 elem_count, GpuMemory(x), incx,
742 GpuMemoryMutable(result));
743 }
744
745 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
746 const DeviceMemory<double> &x, int incx,
747 DeviceMemory<double> *result) {
748 return DoBlasInternal(cublasDnrm2, stream, false /* = pointer_mode_host */,
749 elem_count, GpuMemory(x), incx,
750 GpuMemoryMutable(result));
751 }
752
753 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
754 const DeviceMemory<std::complex<float>> &x, int incx,
755 DeviceMemory<float> *result) {
756 return DoBlasInternal(cublasScnrm2, stream, false /* = pointer_mode_host */,
757 elem_count, GpuComplex(GpuMemory(x)), incx,
758 GpuMemoryMutable(result));
759 }
760
761 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
762 const DeviceMemory<std::complex<double>> &x, int incx,
763 DeviceMemory<double> *result) {
764 return DoBlasInternal(cublasDznrm2, stream, false /* = pointer_mode_host */,
765 elem_count, GpuComplex(GpuMemory(x)), incx,
766 GpuMemoryMutable(result));
767 }
768
769 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
770 DeviceMemory<float> *x, int incx,
771 DeviceMemory<float> *y, int incy, float c, float s) {
772 return DoBlasInternal(cublasSrot, stream, true /* = pointer_mode_host */,
773 elem_count, GpuMemoryMutable(x), incx,
774 GpuMemoryMutable(y), incy, &c, &s);
775 }
776
777 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
778 DeviceMemory<double> *x, int incx,
779 DeviceMemory<double> *y, int incy, double c,
780 double s) {
781 return DoBlasInternal(cublasDrot, stream, true /* = pointer_mode_host */,
782 elem_count, GpuMemoryMutable(x), incx,
783 GpuMemoryMutable(y), incy, &c, &s);
784 }
785
786 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
787 DeviceMemory<std::complex<float>> *x, int incx,
788 DeviceMemory<std::complex<float>> *y, int incy,
789 float c, float s) {
790 return DoBlasInternal(cublasCsrot, stream, true /* = pointer_mode_host */,
791 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
792 GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
793 }
794
795 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
796 DeviceMemory<std::complex<double>> *x, int incx,
797 DeviceMemory<std::complex<double>> *y, int incy,
798 double c, double s) {
799 return DoBlasInternal(cublasZdrot, stream, true /* = pointer_mode_host */,
800 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
801 GpuComplex(GpuMemoryMutable(y)), incy, &c, &s);
802 }
803
804 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
805 DeviceMemory<float> *b, DeviceMemory<float> *c,
806 DeviceMemory<float> *s) {
807 return DoBlasInternal(cublasSrotg, stream, false /* = pointer_mode_host */,
808 GpuMemoryMutable(a), GpuMemoryMutable(b),
809 GpuMemoryMutable(c), GpuMemoryMutable(s));
810 }
811
812 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
813 DeviceMemory<double> *b, DeviceMemory<double> *c,
814 DeviceMemory<double> *s) {
815 return DoBlasInternal(cublasDrotg, stream, false /* = pointer_mode_host */,
816 GpuComplex(GpuMemoryMutable(a)), GpuMemoryMutable(b),
817 GpuMemoryMutable(c), GpuMemoryMutable(s));
818 }
819
820 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
821 DeviceMemory<std::complex<float>> *b,
822 DeviceMemory<float> *c,
823 DeviceMemory<std::complex<float>> *s) {
824 return DoBlasInternal(
825 cublasCrotg, stream, false /* = pointer_mode_host */,
826 GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
827 GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
828 }
829
830 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
831 DeviceMemory<std::complex<double>> *b,
832 DeviceMemory<double> *c,
833 DeviceMemory<std::complex<double>> *s) {
834 return DoBlasInternal(
835 cublasZrotg, stream, false /* = pointer_mode_host */,
836 GpuComplex(GpuMemoryMutable(a)), GpuComplex(GpuMemoryMutable(b)),
837 GpuComplex(GpuMemoryMutable(c)), GpuComplex(GpuMemoryMutable(s)));
838 }
839
840 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
841 DeviceMemory<float> *x, int incx,
842 DeviceMemory<float> *y, int incy,
843 const DeviceMemory<float> ¶m) {
844 return DoBlasInternal(cublasSrotm, stream, false /* = pointer_mode_host */,
845 elem_count, GpuMemoryMutable(x), incx,
846 GpuMemoryMutable(y), incy, GpuMemory(param));
847 }
848
849 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
850 DeviceMemory<double> *x, int incx,
851 DeviceMemory<double> *y, int incy,
852 const DeviceMemory<double> ¶m) {
853 return DoBlasInternal(cublasDrotm, stream, false /* = pointer_mode_host */,
854 elem_count, GpuMemoryMutable(x), incx,
855 GpuMemoryMutable(y), incy, GpuMemory(param));
856 }
857
858 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
859 DeviceMemory<float> *d2, DeviceMemory<float> *x1,
860 const DeviceMemory<float> &y1,
861 DeviceMemory<float> *param) {
862 return DoBlasInternal(cublasSrotmg, stream, false /* = pointer_mode_host */,
863 GpuMemoryMutable(d1), GpuMemoryMutable(d2),
864 GpuMemoryMutable(x1), GpuMemory(y1),
865 GpuMemoryMutable(param));
866 }
867
868 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
869 DeviceMemory<double> *d2, DeviceMemory<double> *x1,
870 const DeviceMemory<double> &y1,
871 DeviceMemory<double> *param) {
872 return DoBlasInternal(cublasDrotmg, stream, false /* = pointer_mode_host */,
873 GpuMemoryMutable(d1), GpuMemoryMutable(d2),
874 GpuMemoryMutable(x1), GpuMemory(y1),
875 GpuMemoryMutable(param));
876 }
877
878 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
879 DeviceMemory<float> *x, int incx) {
880 return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */,
881 elem_count, &alpha, GpuMemoryMutable(x), incx);
882 }
883
884 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
885 DeviceMemory<double> *x, int incx) {
886 return DoBlasInternal(cublasDscal, stream, true /* = pointer_mode_host */,
887 elem_count, &alpha, GpuMemoryMutable(x), incx);
888 }
889
890 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
891 DeviceMemory<std::complex<float>> *x, int incx) {
892 return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
893 elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
894 incx);
895 }
896
897 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
898 DeviceMemory<std::complex<double>> *x, int incx) {
899 return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
900 elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
901 incx);
902 }
903
904 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
905 std::complex<float> alpha,
906 DeviceMemory<std::complex<float>> *x, int incx) {
907 auto cb_alpha = GpuComplexValue(alpha);
908 return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
909 elem_count, GpuComplex(&cb_alpha),
910 GpuComplex(GpuMemoryMutable(x)), incx);
911 }
912
913 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
914 std::complex<double> alpha,
915 DeviceMemory<std::complex<double>> *x, int incx) {
916 auto cb_alpha = GpuComplexValue(alpha);
917 return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
918 elem_count, GpuComplex(&cb_alpha),
919 GpuComplex(GpuMemoryMutable(x)), incx);
920 }
921
922 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
923 DeviceMemory<float> *x, int incx,
924 DeviceMemory<float> *y, int incy) {
925 return DoBlasInternal(cublasSswap, stream, true /* = pointer_mode_host */,
926 elem_count, GpuMemoryMutable(x), incx,
927 GpuMemoryMutable(y), incy);
928 }
929
930 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
931 DeviceMemory<double> *x, int incx,
932 DeviceMemory<double> *y, int incy) {
933 return DoBlasInternal(cublasDswap, stream, true /* = pointer_mode_host */,
934 elem_count, GpuMemoryMutable(x), incx,
935 GpuMemoryMutable(y), incy);
936 }
937
938 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
939 DeviceMemory<std::complex<float>> *x, int incx,
940 DeviceMemory<std::complex<float>> *y, int incy) {
941 return DoBlasInternal(cublasCswap, stream, true /* = pointer_mode_host */,
942 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
943 GpuComplex(GpuMemoryMutable(y)), incy);
944 }
945
946 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
947 DeviceMemory<std::complex<double>> *x, int incx,
948 DeviceMemory<std::complex<double>> *y, int incy) {
949 return DoBlasInternal(cublasZswap, stream, true /* = pointer_mode_host */,
950 elem_count, GpuComplex(GpuMemoryMutable(x)), incx,
951 GpuComplex(GpuMemoryMutable(y)), incy);
952 }
953
954 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
955 const DeviceMemory<float> &x, int incx,
956 DeviceMemory<int> *result) {
957 return DoBlasInternal(cublasIsamax, stream, false /* = pointer_mode_host */,
958 elem_count, GpuMemory(x), incx,
959 GpuMemoryMutable(result));
960 }
961
962 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
963 const DeviceMemory<double> &x, int incx,
964 DeviceMemory<int> *result) {
965 return DoBlasInternal(cublasIdamax, stream, false /* = pointer_mode_host */,
966 elem_count, GpuMemory(x), incx,
967 GpuMemoryMutable(result));
968 }
969
970 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
971 const DeviceMemory<std::complex<float>> &x, int incx,
972 DeviceMemory<int> *result) {
973 return DoBlasInternal(cublasIcamax, stream, false /* = pointer_mode_host */,
974 elem_count, GpuComplex(GpuMemory(x)), incx,
975 GpuMemoryMutable(result));
976 }
977
978 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
979 const DeviceMemory<std::complex<double>> &x,
980 int incx, DeviceMemory<int> *result) {
981 return DoBlasInternal(cublasIzamax, stream, false /* = pointer_mode_host */,
982 elem_count, GpuComplex(GpuMemory(x)), incx,
983 GpuMemoryMutable(result));
984 }
985
986 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
987 const DeviceMemory<float> &x, int incx,
988 DeviceMemory<int> *result) {
989 return DoBlasInternal(cublasIsamin, stream, false /* = pointer_mode_host */,
990 elem_count, GpuComplex(GpuMemory(x)), incx,
991 GpuMemoryMutable(result));
992 }
993
994 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
995 const DeviceMemory<double> &x, int incx,
996 DeviceMemory<int> *result) {
997 return DoBlasInternal(cublasIdamin, stream, false /* = pointer_mode_host */,
998 elem_count, GpuComplex(GpuMemory(x)), incx,
999 GpuMemoryMutable(result));
1000 }
1001
1002 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1003 const DeviceMemory<std::complex<float>> &x, int incx,
1004 DeviceMemory<int> *result) {
1005 return DoBlasInternal(cublasIcamin, stream, false /* = pointer_mode_host */,
1006 elem_count, GpuComplex(GpuMemory(x)), incx,
1007 GpuMemoryMutable(result));
1008 }
1009
1010 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
1011 const DeviceMemory<std::complex<double>> &x,
1012 int incx, DeviceMemory<int> *result) {
1013 return DoBlasInternal(cublasIzamin, stream, false /* = pointer_mode_host */,
1014 elem_count, GpuComplex(GpuMemory(x)), incx,
1015 GpuMemoryMutable(result));
1016 }
1017
1018 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1019 uint64 n, uint64 kl, uint64 ku, float alpha,
1020 const DeviceMemory<float> &a, int lda,
1021 const DeviceMemory<float> &x, int incx, float beta,
1022 DeviceMemory<float> *y, int incy) {
1023 return DoBlasInternal(cublasSgbmv, stream, true /* = pointer_mode_host */,
1024 CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1025 GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1026 GpuMemoryMutable(y), incy);
1027 }
1028
1029 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1030 uint64 n, uint64 kl, uint64 ku, double alpha,
1031 const DeviceMemory<double> &a, int lda,
1032 const DeviceMemory<double> &x, int incx, double beta,
1033 DeviceMemory<double> *y, int incy) {
1034 return DoBlasInternal(cublasDgbmv, stream, true /* = pointer_mode_host */,
1035 CUDABlasTranspose(trans), m, n, kl, ku, &alpha,
1036 GpuMemory(a), lda, GpuMemory(x), incx, &beta,
1037 GpuMemoryMutable(y), incy);
1038 }
1039
1040 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1041 uint64 n, uint64 kl, uint64 ku,
1042 std::complex<float> alpha,
1043 const DeviceMemory<std::complex<float>> &a, int lda,
1044 const DeviceMemory<std::complex<float>> &x, int incx,
1045 std::complex<float> beta,
1046 DeviceMemory<std::complex<float>> *y, int incy) {
1047 auto cb_alpha = GpuComplexValue(alpha);
1048 auto cb_beta = GpuComplexValue(beta);
1049 return DoBlasInternal(cublasCgbmv, stream, true /* = pointer_mode_host */,
1050 CUDABlasTranspose(trans), m, n, kl, ku,
1051 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1052 GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1053 GpuComplex(GpuMemoryMutable(y)), incy);
1054 }
1055
1056 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
1057 uint64 n, uint64 kl, uint64 ku,
1058 std::complex<double> alpha,
1059 const DeviceMemory<std::complex<double>> &a, int lda,
1060 const DeviceMemory<std::complex<double>> &x, int incx,
1061 std::complex<double> beta,
1062 DeviceMemory<std::complex<double>> *y, int incy) {
1063 auto cb_alpha = GpuComplexValue(alpha);
1064 auto cb_beta = GpuComplexValue(beta);
1065 return DoBlasInternal(cublasZgbmv, stream, true /* = pointer_mode_host */,
1066 CUDABlasTranspose(trans), m, n, kl, ku,
1067 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1068 GpuComplex(GpuMemory(x)), incx, GpuComplex(&cb_beta),
1069 GpuComplex(GpuMemoryMutable(y)), incy);
1070 }
1071
1072 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1073 uint64 n, float alpha, const DeviceMemory<float> &a,
1074 int lda, const DeviceMemory<float> &x, int incx,
1075 float beta, DeviceMemory<float> *y, int incy) {
1076 return DoBlasInternal(cublasSgemv, stream, true /* = pointer_mode_host */,
1077 CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1078 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1079 incy);
1080 }
1081
1082 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1083 uint64 n, double alpha, const DeviceMemory<double> &a,
1084 int lda, const DeviceMemory<double> &x, int incx,
1085 double beta, DeviceMemory<double> *y, int incy) {
1086 return DoBlasInternal(cublasDgemv, stream, true /* = pointer_mode_host */,
1087 CUDABlasTranspose(trans), m, n, &alpha, GpuMemory(a),
1088 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1089 incy);
1090 }
1091
1092 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1093 uint64 n, std::complex<float> alpha,
1094 const DeviceMemory<std::complex<float>> &a, int lda,
1095 const DeviceMemory<std::complex<float>> &x, int incx,
1096 std::complex<float> beta,
1097 DeviceMemory<std::complex<float>> *y, int incy) {
1098 auto cb_alpha = GpuComplexValue(alpha);
1099 auto cb_beta = GpuComplexValue(beta);
1100 return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
1101 CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1102 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1103 incx, GpuComplex(&cb_beta),
1104 GpuComplex(GpuMemoryMutable(y)), incy);
1105 }
1106
1107 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
1108 uint64 n, std::complex<double> alpha,
1109 const DeviceMemory<std::complex<double>> &a, int lda,
1110 const DeviceMemory<std::complex<double>> &x, int incx,
1111 std::complex<double> beta,
1112 DeviceMemory<std::complex<double>> *y, int incy) {
1113 auto cb_alpha = GpuComplexValue(alpha);
1114 auto cb_beta = GpuComplexValue(beta);
1115 return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
1116 CUDABlasTranspose(trans), m, n, GpuComplex(&cb_alpha),
1117 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1118 incx, GpuComplex(&cb_beta),
1119 GpuComplex(GpuMemoryMutable(y)), incy);
1120 }
1121
1122 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
1123 const DeviceMemory<float> &x, int incx,
1124 const DeviceMemory<float> &y, int incy,
1125 DeviceMemory<float> *a, int lda) {
1126 return DoBlasInternal(cublasSger, stream, true /* = pointer_mode_host */, m,
1127 n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1128 GpuMemoryMutable(a), lda);
1129 }
1130
1131 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
1132 const DeviceMemory<double> &x, int incx,
1133 const DeviceMemory<double> &y, int incy,
1134 DeviceMemory<double> *a, int lda) {
1135 return DoBlasInternal(cublasDger, stream, true /* = pointer_mode_host */, m,
1136 n, &alpha, GpuMemory(x), incx, GpuMemory(y), incy,
1137 GpuMemoryMutable(a), lda);
1138 }
1139
1140 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1141 std::complex<float> alpha,
1142 const DeviceMemory<std::complex<float>> &x, int incx,
1143 const DeviceMemory<std::complex<float>> &y, int incy,
1144 DeviceMemory<std::complex<float>> *a, int lda) {
1145 auto cb_alpha = GpuComplexValue(alpha);
1146 return DoBlasInternal(cublasCgerc, stream, true /* = pointer_mode_host */, m,
1147 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1148 incx, GpuComplex(GpuMemory(y)), incy,
1149 GpuComplex(GpuMemoryMutable(a)), lda);
1150 }
1151
1152 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1153 std::complex<double> alpha,
1154 const DeviceMemory<std::complex<double>> &x, int incx,
1155 const DeviceMemory<std::complex<double>> &y, int incy,
1156 DeviceMemory<std::complex<double>> *a, int lda) {
1157 auto cb_alpha = GpuComplexValue(alpha);
1158 return DoBlasInternal(cublasZgerc, stream, true /* = pointer_mode_host */, m,
1159 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1160 incx, GpuComplex(GpuMemory(y)), incy,
1161 GpuComplex(GpuMemoryMutable(a)), lda);
1162 }
1163
1164 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1165 std::complex<float> alpha,
1166 const DeviceMemory<std::complex<float>> &x, int incx,
1167 const DeviceMemory<std::complex<float>> &y, int incy,
1168 DeviceMemory<std::complex<float>> *a, int lda) {
1169 auto cb_alpha = GpuComplexValue(alpha);
1170 return DoBlasInternal(cublasCgeru, stream, true /* = pointer_mode_host */, m,
1171 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1172 incx, GpuComplex(GpuMemory(y)), incy,
1173 GpuComplex(GpuMemoryMutable(a)), lda);
1174 }
1175
1176 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1177 std::complex<double> alpha,
1178 const DeviceMemory<std::complex<double>> &x, int incx,
1179 const DeviceMemory<std::complex<double>> &y, int incy,
1180 DeviceMemory<std::complex<double>> *a, int lda) {
1181 auto cb_alpha = GpuComplexValue(alpha);
1182 return DoBlasInternal(cublasZgeru, stream, true /* = pointer_mode_host */, m,
1183 n, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(x)),
1184 incx, GpuComplex(GpuMemory(y)), incy,
1185 GpuComplex(GpuMemoryMutable(a)), lda);
1186 }
1187
1188 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1189 uint64 k, std::complex<float> alpha,
1190 const DeviceMemory<std::complex<float>> &a, int lda,
1191 const DeviceMemory<std::complex<float>> &x, int incx,
1192 std::complex<float> beta,
1193 DeviceMemory<std::complex<float>> *y, int incy) {
1194 auto cb_alpha = GpuComplexValue(alpha);
1195 auto cb_beta = GpuComplexValue(beta);
1196 return DoBlasInternal(cublasChbmv, stream, true /* = pointer_mode_host */,
1197 CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
1198 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1199 incx, GpuComplex(&cb_beta),
1200 GpuComplex(GpuMemoryMutable(y)), incy);
1201 }
1202
1203 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1204 uint64 k, std::complex<double> alpha,
1205 const DeviceMemory<std::complex<double>> &a, int lda,
1206 const DeviceMemory<std::complex<double>> &x, int incx,
1207 std::complex<double> beta,
1208 DeviceMemory<std::complex<double>> *y, int incy) {
1209 auto cb_alpha = GpuComplexValue(alpha);
1210 auto cb_beta = GpuComplexValue(beta);
1211 return DoBlasInternal(cublasZhbmv, stream, true /* = pointer_mode_host */,
1212 CUDABlasUpperLower(uplo), n, k, GpuComplex(&cb_alpha),
1213 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1214 incx, GpuComplex(&cb_beta),
1215 GpuComplex(GpuMemoryMutable(y)), incy);
1216 }
1217
1218 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1219 std::complex<float> alpha,
1220 const DeviceMemory<std::complex<float>> &a, int lda,
1221 const DeviceMemory<std::complex<float>> &x, int incx,
1222 std::complex<float> beta,
1223 DeviceMemory<std::complex<float>> *y, int incy) {
1224 auto cb_alpha = GpuComplexValue(alpha);
1225 auto cb_beta = GpuComplexValue(beta);
1226 return DoBlasInternal(cublasChemv, stream, true /* = pointer_mode_host */,
1227 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1228 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1229 incx, GpuComplex(&cb_beta),
1230 GpuComplex(GpuMemoryMutable(y)), incy);
1231 }
1232
1233 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1234 std::complex<double> alpha,
1235 const DeviceMemory<std::complex<double>> &a, int lda,
1236 const DeviceMemory<std::complex<double>> &x, int incx,
1237 std::complex<double> beta,
1238 DeviceMemory<std::complex<double>> *y, int incy) {
1239 auto cb_alpha = GpuComplexValue(alpha);
1240 auto cb_beta = GpuComplexValue(beta);
1241 return DoBlasInternal(cublasZhemv, stream, true /* = pointer_mode_host */,
1242 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1243 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
1244 incx, GpuComplex(&cb_beta),
1245 GpuComplex(GpuMemoryMutable(y)), incy);
1246 }
1247
1248 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1249 float alpha,
1250 const DeviceMemory<std::complex<float>> &x, int incx,
1251 DeviceMemory<std::complex<float>> *a, int lda) {
1252 return DoBlasInternal(cublasCher, stream, true /* = pointer_mode_host */,
1253 CUDABlasUpperLower(uplo), n, &alpha,
1254 GpuComplex(GpuMemory(x)), incx,
1255 GpuComplex(GpuMemoryMutable(a)), lda);
1256 }
1257
1258 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1259 double alpha,
1260 const DeviceMemory<std::complex<double>> &x, int incx,
1261 DeviceMemory<std::complex<double>> *a, int lda) {
1262 return DoBlasInternal(cublasZher, stream, true /* = pointer_mode_host */,
1263 CUDABlasUpperLower(uplo), n, &alpha,
1264 GpuComplex(GpuMemory(x)), incx,
1265 GpuComplex(GpuMemoryMutable(a)), lda);
1266 }
1267
1268 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1269 std::complex<float> alpha,
1270 const DeviceMemory<std::complex<float>> &x, int incx,
1271 const DeviceMemory<std::complex<float>> &y, int incy,
1272 DeviceMemory<std::complex<float>> *a, int lda) {
1273 auto cb_alpha = GpuComplexValue(alpha);
1274 return DoBlasInternal(cublasCher2, stream, true /* = pointer_mode_host */,
1275 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1276 GpuComplex(GpuMemory(x)), incx,
1277 GpuComplex(GpuMemory(y)), incy,
1278 GpuComplex(GpuMemoryMutable(a)), lda);
1279 }
1280
1281 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1282 std::complex<double> alpha,
1283 const DeviceMemory<std::complex<double>> &x, int incx,
1284 const DeviceMemory<std::complex<double>> &y, int incy,
1285 DeviceMemory<std::complex<double>> *a, int lda) {
1286 auto cb_alpha = GpuComplexValue(alpha);
1287 return DoBlasInternal(cublasZher2, stream, true /* = pointer_mode_host */,
1288 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1289 GpuComplex(GpuMemory(x)), incx,
1290 GpuComplex(GpuMemory(y)), incy,
1291 GpuComplex(GpuMemoryMutable(a)), lda);
1292 }
1293
1294 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1295 std::complex<float> alpha,
1296 const DeviceMemory<std::complex<float>> &ap,
1297 const DeviceMemory<std::complex<float>> &x, int incx,
1298 std::complex<float> beta,
1299 DeviceMemory<std::complex<float>> *y, int incy) {
1300 auto cb_alpha = GpuComplexValue(alpha);
1301 auto cb_beta = GpuComplexValue(beta);
1302 return DoBlasInternal(cublasChpmv, stream, true /* = pointer_mode_host */,
1303 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1304 GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
1305 incx, GpuComplex(&cb_beta),
1306 GpuComplex(GpuMemoryMutable(y)), incy);
1307 }
1308
1309 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1310 std::complex<double> alpha,
1311 const DeviceMemory<std::complex<double>> &ap,
1312 const DeviceMemory<std::complex<double>> &x, int incx,
1313 std::complex<double> beta,
1314 DeviceMemory<std::complex<double>> *y, int incy) {
1315 auto cb_alpha = GpuComplexValue(alpha);
1316 auto cb_beta = GpuComplexValue(beta);
1317 return DoBlasInternal(cublasZhpmv, stream, true /* = pointer_mode_host */,
1318 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1319 GpuComplex(GpuMemory(ap)), GpuComplex(GpuMemory(x)),
1320 incx, GpuComplex(&cb_beta),
1321 GpuComplex(GpuMemoryMutable(y)), incy);
1322 }
1323
1324 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1325 float alpha,
1326 const DeviceMemory<std::complex<float>> &x, int incx,
1327 DeviceMemory<std::complex<float>> *ap) {
1328 return DoBlasInternal(cublasChpr, stream, true /* = pointer_mode_host */,
1329 CUDABlasUpperLower(uplo), n, &alpha,
1330 GpuComplex(GpuMemory(x)), incx,
1331 GpuComplex(GpuMemoryMutable(ap)));
1332 }
1333
1334 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1335 double alpha,
1336 const DeviceMemory<std::complex<double>> &x, int incx,
1337 DeviceMemory<std::complex<double>> *ap) {
1338 return DoBlasInternal(cublasZhpr, stream, true /* = pointer_mode_host */,
1339 CUDABlasUpperLower(uplo), n, &alpha,
1340 GpuComplex(GpuMemory(x)), incx,
1341 GpuComplex(GpuMemoryMutable(ap)));
1342 }
1343
1344 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1345 std::complex<float> alpha,
1346 const DeviceMemory<std::complex<float>> &x, int incx,
1347 const DeviceMemory<std::complex<float>> &y, int incy,
1348 DeviceMemory<std::complex<float>> *ap) {
1349 auto cb_alpha = GpuComplexValue(alpha);
1350 return DoBlasInternal(cublasChpr2, stream, true /* = pointer_mode_host */,
1351 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1352 GpuComplex(GpuMemory(x)), incx,
1353 GpuComplex(GpuMemory(y)), incy,
1354 GpuComplex(GpuMemoryMutable(ap)));
1355 }
1356
1357 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1358 std::complex<double> alpha,
1359 const DeviceMemory<std::complex<double>> &x, int incx,
1360 const DeviceMemory<std::complex<double>> &y, int incy,
1361 DeviceMemory<std::complex<double>> *ap) {
1362 auto cb_alpha = GpuComplexValue(alpha);
1363 return DoBlasInternal(cublasZhpr2, stream, true /* = pointer_mode_host */,
1364 CUDABlasUpperLower(uplo), n, GpuComplex(&cb_alpha),
1365 GpuComplex(GpuMemory(x)), incx,
1366 GpuComplex(GpuMemory(y)), incy,
1367 GpuComplex(GpuMemoryMutable(ap)));
1368 }
1369
1370 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1371 uint64 k, float alpha, const DeviceMemory<float> &a,
1372 int lda, const DeviceMemory<float> &x, int incx,
1373 float beta, DeviceMemory<float> *y, int incy) {
1374 return DoBlasInternal(cublasSsbmv, stream, true /* = pointer_mode_host */,
1375 CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
1376 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1377 incy);
1378 }
1379
1380 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1381 uint64 k, double alpha, const DeviceMemory<double> &a,
1382 int lda, const DeviceMemory<double> &x, int incx,
1383 double beta, DeviceMemory<double> *y, int incy) {
1384 return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */,
1385 CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
1386 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
1387 incy);
1388 }
1389
1390 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1391 float alpha, const DeviceMemory<float> &ap,
1392 const DeviceMemory<float> &x, int incx, float beta,
1393 DeviceMemory<float> *y, int incy) {
1394 return DoBlasInternal(cublasSspmv, stream, true /* = pointer_mode_host */,
1395 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1396 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1397 }
1398
1399 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1400 double alpha, const DeviceMemory<double> &ap,
1401 const DeviceMemory<double> &x, int incx, double beta,
1402 DeviceMemory<double> *y, int incy) {
1403 return DoBlasInternal(cublasDspmv, stream, true /* = pointer_mode_host */,
1404 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1405 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1406 }
1407
1408 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1409 float alpha, const DeviceMemory<float> &x, int incx,
1410 DeviceMemory<float> *ap) {
1411 return DoBlasInternal(cublasSspr, stream, true /* = pointer_mode_host */,
1412 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1413 GpuMemoryMutable(ap));
1414 }
1415
1416 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1417 double alpha, const DeviceMemory<double> &x, int incx,
1418 DeviceMemory<double> *ap) {
1419 return DoBlasInternal(cublasDspr, stream, true /* = pointer_mode_host */,
1420 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1421 GpuMemoryMutable(ap));
1422 }
1423
1424 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1425 float alpha, const DeviceMemory<float> &x, int incx,
1426 const DeviceMemory<float> &y, int incy,
1427 DeviceMemory<float> *ap) {
1428 return DoBlasInternal(cublasSspr2, stream, true /* = pointer_mode_host */,
1429 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1430 GpuMemory(y), incy, GpuMemoryMutable(ap));
1431 }
1432
1433 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1434 double alpha, const DeviceMemory<double> &x, int incx,
1435 const DeviceMemory<double> &y, int incy,
1436 DeviceMemory<double> *ap) {
1437 return DoBlasInternal(cublasDspr2, stream, true /* = pointer_mode_host */,
1438 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1439 GpuMemory(y), incy, GpuMemoryMutable(ap));
1440 }
1441
1442 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1443 float alpha, const DeviceMemory<float> &a, int lda,
1444 const DeviceMemory<float> &x, int incx, float beta,
1445 DeviceMemory<float> *y, int incy) {
1446 return DoBlasInternal(cublasSsymv, stream, true /* = pointer_mode_host */,
1447 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1448 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1449 }
1450
1451 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1452 double alpha, const DeviceMemory<double> &a, int lda,
1453 const DeviceMemory<double> &x, int incx, double beta,
1454 DeviceMemory<double> *y, int incy) {
1455 return DoBlasInternal(cublasDsymv, stream, true /* = pointer_mode_host */,
1456 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1457 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1458 }
1459
1460 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1461 float alpha, const DeviceMemory<float> &x, int incx,
1462 DeviceMemory<float> *a, int lda) {
1463 return DoBlasInternal(cublasSsyr, stream, true /* = pointer_mode_host */,
1464 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1465 GpuMemoryMutable(a), lda);
1466 }
1467
1468 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1469 double alpha, const DeviceMemory<double> &x, int incx,
1470 DeviceMemory<double> *a, int lda) {
1471 return DoBlasInternal(cublasDsyr, stream, true /* = pointer_mode_host */,
1472 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1473 GpuMemoryMutable(a), lda);
1474 }
1475
1476 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1477 float alpha, const DeviceMemory<float> &x, int incx,
1478 const DeviceMemory<float> &y, int incy,
1479 DeviceMemory<float> *a, int lda) {
1480 return DoBlasInternal(cublasSsyr2, stream, true /* = pointer_mode_host */,
1481 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1482 GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1483 }
1484
1485 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1486 double alpha, const DeviceMemory<double> &x, int incx,
1487 const DeviceMemory<double> &y, int incy,
1488 DeviceMemory<double> *a, int lda) {
1489 return DoBlasInternal(cublasDsyr2, stream, true /* = pointer_mode_host */,
1490 CUDABlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1491 GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1492 }
1493
1494 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1495 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1496 uint64 k, const DeviceMemory<float> &a, int lda,
1497 DeviceMemory<float> *x, int incx) {
1498 return DoBlasInternal(cublasStbmv, stream, true /* = pointer_mode_host */,
1499 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1500 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1501 GpuMemoryMutable(x), incx);
1502 }
1503
1504 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1505 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1506 uint64 k, const DeviceMemory<double> &a, int lda,
1507 DeviceMemory<double> *x, int incx) {
1508 return DoBlasInternal(cublasDtbmv, stream, true /* = pointer_mode_host */,
1509 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1510 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1511 GpuMemoryMutable(x), incx);
1512 }
1513
1514 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1515 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1516 uint64 k, const DeviceMemory<std::complex<float>> &a,
1517 int lda, DeviceMemory<std::complex<float>> *x,
1518 int incx) {
1519 return DoBlasInternal(cublasCtbmv, stream, true /* = pointer_mode_host */,
1520 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1521 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1522 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1523 }
1524
1525 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1526 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1527 uint64 k, const DeviceMemory<std::complex<double>> &a,
1528 int lda, DeviceMemory<std::complex<double>> *x,
1529 int incx) {
1530 return DoBlasInternal(cublasZtbmv, stream, true /* = pointer_mode_host */,
1531 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1532 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1533 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1534 }
1535
1536 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1537 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1538 uint64 k, const DeviceMemory<float> &a, int lda,
1539 DeviceMemory<float> *x, int incx) {
1540 return DoBlasInternal(cublasStbsv, stream, true /* = pointer_mode_host */,
1541 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1542 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1543 GpuMemoryMutable(x), incx);
1544 }
1545
1546 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1547 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1548 uint64 k, const DeviceMemory<double> &a, int lda,
1549 DeviceMemory<double> *x, int incx) {
1550 return DoBlasInternal(cublasDtbsv, stream, true /* = pointer_mode_host */,
1551 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1552 CUDABlasDiagonal(diag), n, k, GpuMemory(a), lda,
1553 GpuMemoryMutable(x), incx);
1554 }
1555
1556 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1557 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1558 uint64 k, const DeviceMemory<std::complex<float>> &a,
1559 int lda, DeviceMemory<std::complex<float>> *x,
1560 int incx) {
1561 return DoBlasInternal(cublasCtbsv, stream, true /* = pointer_mode_host */,
1562 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1563 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1564 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1565 }
1566
1567 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1568 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1569 uint64 k, const DeviceMemory<std::complex<double>> &a,
1570 int lda, DeviceMemory<std::complex<double>> *x,
1571 int incx) {
1572 return DoBlasInternal(cublasZtbsv, stream, true /* = pointer_mode_host */,
1573 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1574 CUDABlasDiagonal(diag), n, k, GpuComplex(GpuMemory(a)),
1575 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1576 }
1577
1578 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1579 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1580 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1581 int incx) {
1582 return DoBlasInternal(cublasStpmv, stream, true /* = pointer_mode_host */,
1583 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1584 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1585 GpuMemoryMutable(x), incx);
1586 }
1587
1588 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1589 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1590 const DeviceMemory<double> &ap,
1591 DeviceMemory<double> *x, int incx) {
1592 return DoBlasInternal(cublasDtpmv, stream, true /* = pointer_mode_host */,
1593 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1594 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1595 GpuMemoryMutable(x), incx);
1596 }
1597
1598 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1599 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1600 const DeviceMemory<std::complex<float>> &ap,
1601 DeviceMemory<std::complex<float>> *x, int incx) {
1602 return DoBlasInternal(cublasCtpmv, stream, true /* = pointer_mode_host */,
1603 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1604 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1605 GpuComplex(GpuMemoryMutable(x)), incx);
1606 }
1607
1608 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1609 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1610 const DeviceMemory<std::complex<double>> &ap,
1611 DeviceMemory<std::complex<double>> *x, int incx) {
1612 return DoBlasInternal(cublasZtpmv, stream, true /* = pointer_mode_host */,
1613 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1614 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1615 GpuComplex(GpuMemoryMutable(x)), incx);
1616 }
1617
1618 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1619 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1620 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1621 int incx) {
1622 return DoBlasInternal(cublasStpsv, stream, true /* = pointer_mode_host */,
1623 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1624 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1625 GpuMemoryMutable(x), incx);
1626 }
1627
1628 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1629 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1630 const DeviceMemory<double> &ap,
1631 DeviceMemory<double> *x, int incx) {
1632 return DoBlasInternal(cublasDtpsv, stream, true /* = pointer_mode_host */,
1633 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1634 CUDABlasDiagonal(diag), n, GpuMemory(ap),
1635 GpuMemoryMutable(x), incx);
1636 }
1637
1638 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1639 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1640 const DeviceMemory<std::complex<float>> &ap,
1641 DeviceMemory<std::complex<float>> *x, int incx) {
1642 return DoBlasInternal(cublasCtpsv, stream, true /* = pointer_mode_host */,
1643 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1644 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1645 GpuComplex(GpuMemoryMutable(x)), incx);
1646 }
1647
1648 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1649 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1650 const DeviceMemory<std::complex<double>> &ap,
1651 DeviceMemory<std::complex<double>> *x, int incx) {
1652 return DoBlasInternal(cublasZtpsv, stream, true /* = pointer_mode_host */,
1653 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1654 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(ap)),
1655 GpuComplex(GpuMemoryMutable(x)), incx);
1656 }
1657
1658 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1659 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1660 const DeviceMemory<float> &a, int lda,
1661 DeviceMemory<float> *x, int incx) {
1662 return DoBlasInternal(cublasStrmv, stream, true /* = pointer_mode_host */,
1663 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1664 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1665 GpuMemoryMutable(x), incx);
1666 }
1667
1668 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1669 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1670 const DeviceMemory<double> &a, int lda,
1671 DeviceMemory<double> *x, int incx) {
1672 return DoBlasInternal(cublasDtrmv, stream, true /* = pointer_mode_host */,
1673 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1674 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1675 GpuMemoryMutable(x), incx);
1676 }
1677
1678 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1679 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1680 const DeviceMemory<std::complex<float>> &a, int lda,
1681 DeviceMemory<std::complex<float>> *x, int incx) {
1682 return DoBlasInternal(cublasCtrmv, stream, true /* = pointer_mode_host */,
1683 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1684 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1685 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1686 }
1687
1688 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1689 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1690 const DeviceMemory<std::complex<double>> &a, int lda,
1691 DeviceMemory<std::complex<double>> *x, int incx) {
1692 return DoBlasInternal(cublasZtrmv, stream, true /* = pointer_mode_host */,
1693 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1694 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1695 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1696 }
1697
1698 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1699 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1700 const DeviceMemory<float> &a, int lda,
1701 DeviceMemory<float> *x, int incx) {
1702 return DoBlasInternal(cublasStrsv, stream, true /* = pointer_mode_host */,
1703 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1704 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1705 GpuMemoryMutable(x), incx);
1706 }
1707
1708 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1709 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1710 const DeviceMemory<double> &a, int lda,
1711 DeviceMemory<double> *x, int incx) {
1712 return DoBlasInternal(cublasDtrsv, stream, true /* = pointer_mode_host */,
1713 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1714 CUDABlasDiagonal(diag), n, GpuMemory(a), lda,
1715 GpuMemoryMutable(x), incx);
1716 }
1717
1718 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1719 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1720 const DeviceMemory<std::complex<float>> &a, int lda,
1721 DeviceMemory<std::complex<float>> *x, int incx) {
1722 return DoBlasInternal(cublasCtrsv, stream, true /* = pointer_mode_host */,
1723 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1724 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1725 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1726 }
1727
1728 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1729 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1730 const DeviceMemory<std::complex<double>> &a, int lda,
1731 DeviceMemory<std::complex<double>> *x, int incx) {
1732 return DoBlasInternal(cublasZtrsv, stream, true /* = pointer_mode_host */,
1733 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
1734 CUDABlasDiagonal(diag), n, GpuComplex(GpuMemory(a)),
1735 lda, GpuComplex(GpuMemoryMutable(x)), incx);
1736 }
1737
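// Half-precision GEMM. Rather than a cublasHgemm wrapper, this uses
// cublasSgemmEx, which reads and writes fp16 (SE_CUDA_DATA_HALF) matrices but
// accumulates in fp32, which is why alpha and beta are plain floats here.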
1738 bool CUDABlas::DoBlasGemm(
1739 Stream *stream, blas::Transpose transa,
1740 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1741 float alpha, const DeviceMemory<Eigen::half> &a, int lda,
1742 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1743 DeviceMemory<Eigen::half> *c, int ldc) {
1744 #if CUDA_VERSION >= 7050
1745 VLOG(1) << absl::StrFormat(
1746 "doing cuBLAS SGEMM: at=%d bt=%d m=%u n=%u "
1747 "k=%u alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1748 "c=%p ldc=%d",
1749 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1750 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1751 if (transa == blas::Transpose::kNoTranspose) {
1752 if (lda < static_cast<int64>(m)) {
1753 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1754 "precondition violation";
1755 }
1756 } else {
1757 if (lda < static_cast<int64>(k)) {
1758 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1759 << ") (transpose case); precondition violation";
1760 }
1761 }
1762 if (transb == blas::Transpose::kNoTranspose) {
1763 if (ldb < static_cast<int64>(k)) {
1764 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1765 << ") (no transpose case); precondition violation";
1766 }
1767 } else {
1768 if (ldb < static_cast<int64>(n)) {
1769 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1770 "precondition violation";
1771 }
1772 }
1773
1774 #if CUDA_VERSION < 11000
1775 cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
1776 #else
1777 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
1778 #endif
1779
1780 return DoBlasInternalImpl(
1781 cublasSgemmEx, stream, true /* = pointer_mode_host */,
1782       true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
1783 CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a),
1784 SE_CUDA_DATA_HALF, lda, GpuMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
1785 GpuMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
1786
1787 #else
1788 LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
1789 << "(need at least CUDA 7.5)";
1790 return false;
1791 #endif
1792 }
1793
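// Single-precision GEMM. On CUDA 11+ this requests CUBLAS_TF32_TENSOR_OP_MATH
// so that sm_80+ devices may use TensorFloat-32; DoBlasInternalImpl is
// expected to fall back to CUBLAS_DEFAULT_MATH when TensorFloat-32 execution
// is disabled (see the note in DoBlasGemmBatchedInternal below).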
1794 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1795 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1796 float alpha, const DeviceMemory<float> &a, int lda,
1797 const DeviceMemory<float> &b, int ldb, float beta,
1798 DeviceMemory<float> *c, int ldc) {
1799 VLOG(1) << absl::StrFormat(
1800 "doing cuBLAS SGEMM: at=%d bt=%d m=%u n=%u "
1801 "k=%u alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1802 "c=%p ldc=%d",
1803 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1804 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1805 if (transa == blas::Transpose::kNoTranspose) {
1806 if (lda < static_cast<int64>(m)) {
1807 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1808 "precondition violation";
1809 }
1810 } else {
1811 if (lda < static_cast<int64>(k)) {
1812 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1813 << ") (transpose case); precondition violation";
1814 }
1815 }
1816 if (transb == blas::Transpose::kNoTranspose) {
1817 if (ldb < static_cast<int64>(k)) {
1818 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1819 << ") (no transpose case); precondition violation";
1820 }
1821 } else {
1822 if (ldb < static_cast<int64>(n)) {
1823 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1824 "precondition violation";
1825 }
1826 }
1827
1828 #if CUDA_VERSION < 11000
1829 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
1830 #else
1831 cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
1832 int cc_major, cc_minor;
1833 if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
1834 &cc_major, &cc_minor) &&
1835 cc_major >= 8) {
1836     // TODO(reedwm): Remove this, or make it VLOG(1), once TensorFloat-32 is
1837     // more thoroughly tested.
1838 LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
1839 "multiplication. This will only be logged once.";
1840 }
1841 #endif
1842
1843 return DoBlasInternalImpl(
1844 cublasSgemm, stream, true /* = pointer_mode_host */,
1845 true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
1846 CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda,
1847 GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1848 }
1849
1850 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1851 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1852 double alpha, const DeviceMemory<double> &a, int lda,
1853 const DeviceMemory<double> &b, int ldb, double beta,
1854 DeviceMemory<double> *c, int ldc) {
1855 return DoBlasInternal(cublasDgemm, stream, true /* = pointer_mode_host */,
1856 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
1857 n, k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb,
1858 &beta, GpuMemoryMutable(c), ldc);
1859 }
1860
1861 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1862 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1863 std::complex<float> alpha,
1864 const DeviceMemory<std::complex<float>> &a, int lda,
1865 const DeviceMemory<std::complex<float>> &b, int ldb,
1866 std::complex<float> beta,
1867 DeviceMemory<std::complex<float>> *c, int ldc) {
1868 auto cb_alpha = GpuComplexValue(alpha);
1869 auto cb_beta = GpuComplexValue(beta);
1870 return DoBlasInternal(cublasCgemm, stream, true /* = pointer_mode_host */,
1871 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
1872 n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
1873 lda, GpuComplex(GpuMemory(b)), ldb,
1874 GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
1875 ldc);
1876 }
1877
1878 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1879 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1880 std::complex<double> alpha,
1881 const DeviceMemory<std::complex<double>> &a, int lda,
1882 const DeviceMemory<std::complex<double>> &b, int ldb,
1883 std::complex<double> beta,
1884 DeviceMemory<std::complex<double>> *c, int ldc) {
1885 auto cb_alpha = GpuComplexValue(alpha);
1886 auto cb_beta = GpuComplexValue(beta);
1887 return DoBlasInternal(cublasZgemm, stream, true /* = pointer_mode_host */,
1888 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m,
1889 n, k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)),
1890 lda, GpuComplex(GpuMemory(b)), ldb,
1891 GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
1892 ldc);
1893 }
1894
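// The *WithProfiling entry points below simply forward to the templated
// DoBlasGemvWithProfilingImpl / DoBlasGemmWithProfilingImpl helpers, which
// time the underlying BLAS call with a GpuTimer when a ProfileResult is
// requested.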
1895 bool CUDABlas::DoBlasGemvWithProfiling(
1896 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1897 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1898 int incx, float beta, DeviceMemory<float> *y, int incy,
1899 blas::ProfileResult *output_profile_result) {
1900 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1901 incx, beta, y, incy,
1902 output_profile_result);
1903 }
1904
1905 bool CUDABlas::DoBlasGemvWithProfiling(
1906 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1907 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1908 int incx, double beta, DeviceMemory<double> *y, int incy,
1909 blas::ProfileResult *output_profile_result) {
1910 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1911 incx, beta, y, incy,
1912 output_profile_result);
1913 }
1914
1915 bool CUDABlas::DoBlasGemvWithProfiling(
1916 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1917 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1918 int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1919 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1920 blas::ProfileResult *output_profile_result) {
1921 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1922 incx, beta, y, incy,
1923 output_profile_result);
1924 }
1925
1926 bool CUDABlas::DoBlasGemvWithProfiling(
1927 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1928 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1929 int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1930 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1931 blas::ProfileResult *output_profile_result) {
1932 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1933 incx, beta, y, incy,
1934 output_profile_result);
1935 }
1936
1937 bool CUDABlas::DoBlasGemmWithProfiling(
1938 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1939 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1940 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1941 DeviceMemory<Eigen::half> *c, int ldc,
1942 blas::ProfileResult *output_profile_result) {
1943 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1944 lda, b, ldb, beta, c, ldc,
1945 output_profile_result);
1946 }
1947
1948 bool CUDABlas::DoBlasGemmWithProfiling(
1949 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1950 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1951 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1952 int ldc, blas::ProfileResult *output_profile_result) {
1953 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1954 lda, b, ldb, beta, c, ldc,
1955 output_profile_result);
1956 }
1957
1958 bool CUDABlas::DoBlasGemmWithProfiling(
1959 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1960 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1961 const DeviceMemory<double> &b, int ldb, double beta,
1962 DeviceMemory<double> *c, int ldc,
1963 blas::ProfileResult *output_profile_result) {
1964 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1965 lda, b, ldb, beta, c, ldc,
1966 output_profile_result);
1967 }
1968
1969 bool CUDABlas::DoBlasGemmWithProfiling(
1970 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1971 uint64 n, uint64 k, std::complex<float> alpha,
1972 const DeviceMemory<std::complex<float>> &a, int lda,
1973 const DeviceMemory<std::complex<float>> &b, int ldb,
1974 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1975 blas::ProfileResult *output_profile_result) {
1976 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1977 lda, b, ldb, beta, c, ldc,
1978 output_profile_result);
1979 }
1980
1981 bool CUDABlas::DoBlasGemmWithProfiling(
1982 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1983 uint64 n, uint64 k, std::complex<double> alpha,
1984 const DeviceMemory<std::complex<double>> &a, int lda,
1985 const DeviceMemory<std::complex<double>> &b, int ldb,
1986 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1987 blas::ProfileResult *output_profile_result) {
1988 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1989 lda, b, ldb, beta, c, ldc,
1990 output_profile_result);
1991 }
1992
1993 template <typename T>
1994 bool CUDABlas::DoBlasGemvWithProfilingImpl(
1995 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1996 const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1997 const T &beta, DeviceMemory<T> *y, int incy,
1998 blas::ProfileResult *output_profile_result) {
1999 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2000 if (output_profile_result != nullptr) {
2001 timer.reset(new GpuTimer(parent_));
2002 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2003 return false;
2004 }
2005 }
2006
2007   // Call blasGemv
2008 bool result =
2009 DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
2010
2011 if (timer != nullptr && result) {
2012 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2013 // state.
2014 if (!timer->Stop(AsGpuStream(stream))) {
2015 return false;
2016 }
2017 output_profile_result->set_is_valid(true);
2018 output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
2019 output_profile_result->set_elapsed_time_in_ms(
2020 timer->GetElapsedMilliseconds());
2021 }
2022 return result;
2023 }
2024
2025 template <typename T, typename ParamType>
2026 bool CUDABlas::DoBlasGemmWithProfilingImpl(
2027 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2028 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
2029 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
2030 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
2031 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2032 if (output_profile_result != nullptr) {
2033 timer.reset(new GpuTimer(parent_));
2034 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2035 return false;
2036 }
2037 }
2038
2039 // Call blasGemm
2040 bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
2041 ldb, beta, c, ldc);
2042
2043 if (timer != nullptr && result) {
2044 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2045 // state.
2046 if (!timer->Stop(AsGpuStream(stream))) {
2047 return false;
2048 }
2049 output_profile_result->set_is_valid(true);
2050 output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
2051 output_profile_result->set_elapsed_time_in_ms(
2052 timer->GetElapsedMilliseconds());
2053 }
2054 return result;
2055 }
2056
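// In the cublasGemmAlgo_t enum, the tensor-op algorithm variants are all
// numbered at or above CUBLAS_GEMM_DEFAULT_TENSOR_OP, so a >= comparison is
// sufficient to identify them.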
2057 static bool UsesTensorOps(blas::AlgorithmType algo) {
2058 #if CUDA_VERSION >= 9000
2059 cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
2060 return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
2061 #else
2062 return false;
2063 #endif
2064 }
2065
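// Shared implementation behind the DoBlasGemmWithAlgorithm overloads: it
// checks that the device, CUDA version, and TF32 settings support the
// requested algorithm (returning false rather than failing hard when they do
// not), optionally wraps the call in a GpuTimer when profiling is requested,
// and then dispatches to cublasGemmEx with the corresponding cublasGemmAlgo_t.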
2066 template <typename InT, typename OutT, typename CompT>
2067 bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
2068 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2069 uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
2070 const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
2071 const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
2072 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
2073 blas::ProfileResult *output_profile_result) {
2074 // GPUs < sm_50 don't support cublasGemmEx.
2075 int cc_major, cc_minor;
2076 if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
2077 &cc_major, &cc_minor) &&
2078 cc_major < 5) {
2079 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major
2080 << cc_minor << " devices don't support explicit gemm algorithms.";
2081 return false;
2082 }
2083
2084 bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
2085 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2086 if (algo_uses_tensor_ops) {
2087 if (cc_major < 7) {
2088 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
2089 << algorithm
2090 << " uses tensor ops, but tensor ops are not available in sm"
2091 << cc_major << "X devices.";
2092 return false;
2093 } else if (std::is_same<InT, float>::value) {
2094 #if CUDA_VERSION < 11000
2095 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
2096 << algorithm
2097 << " uses tensor ops, but tensor ops are not available for fp32"
2098 << " inputs.";
2099 return false;
2100 #else
2101 if (cc_major < 8) {
2102 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
2103 << algorithm
2104 << " uses tensor ops, but tensor ops are not available in sm"
2105 << cc_major << "X devices for float input types.";
2106 return false;
2107 } else if (!tensorflow::tensor_float_32_execution_enabled()) {
2108 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
2109 << algorithm
2110 << " uses tensor ops, but tensor ops are disabled for fp32"
2111 << " inputs.";
2112 return false;
2113 }
2114 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2115 #endif
2116 } else if (std::is_same<InT, Eigen::half>::value) {
2117 #if CUDA_VERSION < 11000
2118 math_type = CUBLAS_TENSOR_OP_MATH;
2119 #endif
2120 } else {
2121 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
2122 << algorithm
2123 << " uses tensor ops, which are not supported for InT.";
2124 return false;
2125 }
2126 }
2127
2128 // Either both 'alpha' and 'beta' need to be pointers to device memory, or
2129 // they need to be both host scalars.
2130 if (alpha.is_pointer() != beta.is_pointer()) {
2131 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` "
2132 "and `beta` is a pointer, but the other is not.";
2133 return false;
2134 }
2135
2136 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
2137 if (output_profile_result != nullptr) {
2138 timer.reset(new GpuTimer(parent_));
2139 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
2140 VLOG(2) << "DoBlasGemmWithAlgorithm returning false because "
2141 "output_profile_result was given, but we were unable to "
2142 "create a GpuTimer.";
2143 return false;
2144 }
2145 }
2146
2147 // Return false if we might be hitting a cuBLAS bug that produces the wrong
2148 // result. See nvbugs/2156201, b/79126339.
2149 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
2150 if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
2151 std::max({m, n, k}) >= 2097153 && cc_major < 7) {
2152     VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around a "
2153                "cuBLAS < 9.2 bug with m, n, or k >= 2097153. See b/79126339.";
2154 return false;
2155 }
2156 #endif
2157
2158 cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
2159 // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
2160 // we do the following compile-time check on the default value:
2161 static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
2162 // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
2163   // essentially reinterpret_cast to __half, which is safe because Eigen::half
2164 // inherits from __half.
2165 bool result = DoBlasInternalImpl(
2166 AS_LAMBDA(cublasGemmEx), stream,
2167 /* pointer_mode_host = */ !alpha.is_pointer(), /*err_on_failure=*/false,
2168 math_type, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2169 alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(),
2170 GpuMemory(a), cuda_in_type, lda, GpuMemory(b), cuda_in_type, ldb,
2171 beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(),
2172 GpuMemoryMutable(c), CUDADataType<OutT>::type, ldc,
2173 CUDAComputationType(computation_type),
2174 static_cast<cublasGemmAlgo_t>(algorithm));
2175
2176 if (timer != nullptr && result) {
2177 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
2178 // state.
2179 if (!timer->Stop(AsGpuStream(stream))) {
2180 VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop "
2181 "GpuTimer.";
2182 return false;
2183 }
2184 output_profile_result->set_is_valid(true);
2185 output_profile_result->set_algorithm(algorithm);
2186 output_profile_result->set_elapsed_time_in_ms(
2187 timer->GetElapsedMilliseconds());
2188 }
2189 return result;
2190 }
2191
2192 bool CUDABlas::GetBlasGemmAlgorithms(
2193 std::vector<blas::AlgorithmType> *out_algorithms) {
2194 // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
2195 // were first introduced in CUDA 8.
2196 //
2197   // Note that even when the CUDA version or compute capability is not
2198   // sufficient, we still populate out_algorithms; the caller is responsible
2199   // for checking support before using any of the returned algorithms.
2200 *out_algorithms = {
2201 CUBLAS_GEMM_DFALT,
2202 CUBLAS_GEMM_ALGO0,
2203 CUBLAS_GEMM_ALGO1,
2204 CUBLAS_GEMM_ALGO2,
2205 CUBLAS_GEMM_ALGO3,
2206 CUBLAS_GEMM_ALGO4,
2207 CUBLAS_GEMM_ALGO5,
2208 CUBLAS_GEMM_ALGO6,
2209 CUBLAS_GEMM_ALGO7,
2210 #if CUDA_VERSION >= 9000
2211 CUBLAS_GEMM_ALGO8,
2212 CUBLAS_GEMM_ALGO9,
2213 CUBLAS_GEMM_ALGO10,
2214 CUBLAS_GEMM_ALGO11,
2215 CUBLAS_GEMM_ALGO12,
2216 CUBLAS_GEMM_ALGO13,
2217 CUBLAS_GEMM_ALGO14,
2218 CUBLAS_GEMM_ALGO15,
2219 CUBLAS_GEMM_ALGO16,
2220 CUBLAS_GEMM_ALGO17,
2221 CUBLAS_GEMM_DFALT_TENSOR_OP,
2222 CUBLAS_GEMM_ALGO0_TENSOR_OP,
2223 CUBLAS_GEMM_ALGO1_TENSOR_OP,
2224 CUBLAS_GEMM_ALGO2_TENSOR_OP,
2225 CUBLAS_GEMM_ALGO3_TENSOR_OP,
2226 CUBLAS_GEMM_ALGO4_TENSOR_OP,
2227 #endif
2228 #if CUDA_VERSION >= 9020
2229 CUBLAS_GEMM_ALGO18,
2230 CUBLAS_GEMM_ALGO19,
2231 CUBLAS_GEMM_ALGO20,
2232 CUBLAS_GEMM_ALGO21,
2233 CUBLAS_GEMM_ALGO22,
2234 CUBLAS_GEMM_ALGO23,
2235 CUBLAS_GEMM_ALGO5_TENSOR_OP,
2236 CUBLAS_GEMM_ALGO6_TENSOR_OP,
2237 CUBLAS_GEMM_ALGO7_TENSOR_OP,
2238 CUBLAS_GEMM_ALGO8_TENSOR_OP,
2239 CUBLAS_GEMM_ALGO9_TENSOR_OP,
2240 CUBLAS_GEMM_ALGO10_TENSOR_OP,
2241 CUBLAS_GEMM_ALGO11_TENSOR_OP,
2242 CUBLAS_GEMM_ALGO12_TENSOR_OP,
2243 CUBLAS_GEMM_ALGO13_TENSOR_OP,
2244 CUBLAS_GEMM_ALGO14_TENSOR_OP,
2245 CUBLAS_GEMM_ALGO15_TENSOR_OP,
2246 #endif
2247 };
2248 return true;
2249 }
2250
2251 bool CUDABlas::DoBlasGemmWithAlgorithm(
2252 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2253 uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
2254 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
2255 const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
2256 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
2257 blas::ProfileResult *output_profile_result) {
2258 return DoBlasGemmWithAlgorithmImpl(
2259 stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2260 computation_type, algorithm, output_profile_result);
2261 }
2262
2263 bool CUDABlas::DoBlasGemmWithAlgorithm(
2264 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2265 uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
2266 const DeviceMemory<Eigen::half> &a, int lda,
2267 const DeviceMemory<Eigen::half> &b, int ldb,
2268 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
2269 int ldc, blas::ComputationType computation_type,
2270 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2271 if (computation_type == blas::ComputationType::kF32) {
2272 if (alpha.is_pointer() || beta.is_pointer()) {
2273 // We cannot easily convert a pointer to f16 memory to a pointer to f32
2274 // memory from here, so we don't support this for now.
2275 // TODO(akuegel): Investigate whether we can do the conversion before
2276 // calling DoBlasGemmWithAlgorithm.
2277 return false;
2278 }
2279 HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
2280 HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
2281 return DoBlasGemmWithAlgorithmImpl(
2282 stream, transa, transb, m, n, k, float_alpha, a, lda, b, ldb,
2283 float_beta, c, ldc, computation_type, algorithm, output_profile_result);
2284 }
2285
2286 CHECK_EQ(computation_type, blas::ComputationType::kF16);
2287 return DoBlasGemmWithAlgorithmImpl(
2288 stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2289 computation_type, algorithm, output_profile_result);
2290 }
2291
2292 bool CUDABlas::DoBlasGemmWithAlgorithm(
2293 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2294 uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
2295 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
2296 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
2297 int ldc, blas::ComputationType computation_type,
2298 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2299 return DoBlasGemmWithAlgorithmImpl(
2300 stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2301 computation_type, algorithm, output_profile_result);
2302 }
2303
2304 bool CUDABlas::DoBlasGemmWithAlgorithm(
2305 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2306 uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
2307 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
2308 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
2309 int ldc, blas::ComputationType computation_type,
2310 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
2311 return DoBlasGemmWithAlgorithmImpl(
2312 stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2313 computation_type, algorithm, output_profile_result);
2314 }
2315
2316 bool CUDABlas::DoBlasGemmWithAlgorithm(
2317 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2318 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
2319 const DeviceMemory<std::complex<float>> &a, int lda,
2320 const DeviceMemory<std::complex<float>> &b, int ldb,
2321 const HostOrDeviceScalar<std::complex<float>> &beta,
2322 DeviceMemory<std::complex<float>> *c, int ldc,
2323 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
2324 blas::ProfileResult *output_profile_result) {
2325 return DoBlasGemmWithAlgorithmImpl(
2326 stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2327 computation_type, algorithm, output_profile_result);
2328 }
2329
2330 bool CUDABlas::DoBlasGemmWithAlgorithm(
2331 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2332 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
2333 const DeviceMemory<std::complex<double>> &a, int lda,
2334 const DeviceMemory<std::complex<double>> &b, int ldb,
2335 const HostOrDeviceScalar<std::complex<double>> &beta,
2336 DeviceMemory<std::complex<double>> *c, int ldc,
2337 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
2338 blas::ProfileResult *output_profile_result) {
2339 return DoBlasGemmWithAlgorithmImpl(
2340 stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2341 computation_type, algorithm, output_profile_result);
2342 }
2343
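// HalfAsFloat<T>::type is float when T is Eigen::half and T otherwise.
// DoBlasGemmBatchedInternal uses it as the element type of its device-side
// pointer arrays: for fp16 batches the typed batched function passed in is
// cublasSgemmBatched (which is never actually called; see the note at the
// call sites), and the real dispatch goes through cublasGemmBatchedEx or a
// per-matrix DoBlasGemm loop.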
2344 template <typename T>
2345 struct HalfAsFloat {
2346 typedef T type;
2347 };
2348
2349 template <>
2350 struct HalfAsFloat<Eigen::half> {
2351 typedef float type;
2352 };
2353
2354 namespace {
2355 // Pass-through for non-complex types that don't need conversion to a
2356 // cuBLAS-specific type.
2357 template <typename T>
2358 inline T GpuComplexValue(T v) {
2359 return v;
2360 }
2361 } // namespace
2362
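// Shared implementation behind the DoBlasGemmBatched overloads: it gathers the
// per-batch raw device pointers, stages them in device memory (through the
// scratch allocator when one is provided, otherwise through temporary stream
// allocations), and then dispatches to cublasGemmBatchedEx on CUDA >= 9.1 with
// SM >= 5.0, or to the typed batched call (a per-matrix DoBlasGemm loop for
// fp16) otherwise.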
2363 template <typename T, typename Scalar, typename FuncT>
2364 port::Status CUDABlas::DoBlasGemmBatchedInternal(
2365 FuncT cublas_func, Stream *stream, blas::Transpose transa,
2366 blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
2367 const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
2368 const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
2369 Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
2370 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2371 std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
2372 for (int i = 0; i < batch_count; ++i) {
2373 a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
2374 b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
2375 c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
2376 }
2377
2378 typedef typename HalfAsFloat<typename GpuComplexT<T>::type>::type CUDA_T;
2379
2380 const size_t size = batch_count * sizeof(CUDA_T *);
2381
2382 // Device-side copy of pointers to matrices.
2383 DeviceMemory<CUDA_T *> a;
2384 DeviceMemory<CUDA_T *> b;
2385 DeviceMemory<CUDA_T *> c;
2386
2387 // If temporary space is allocated for device-side copies of pointers to
2388 // matrices, that temporary space should not be freed until this function
2389 // returns. Although the values for these unique_ptrs are not set here, they
2390 // are declared at this scope so they will be destroyed when the function
2391 // returns.
2392 //
2393 // If a scratch allocator is provided, these pointers will not be used at all.
2394 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
2395 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
2396 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;
2397
2398 // Decide how to allocate device-side copy of pointers to matrices based on
2399 // whether a scratch allocator was passed.
2400 if (scratch_allocator != nullptr) {
2401 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
2402 scratch_allocator->AllocateBytes(size));
2403 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
2404 scratch_allocator->AllocateBytes(size));
2405 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
2406 scratch_allocator->AllocateBytes(size));
2407 a = DeviceMemory<CUDA_T *>(a_bytes);
2408 b = DeviceMemory<CUDA_T *>(b_bytes);
2409 c = DeviceMemory<CUDA_T *>(c_bytes);
2410 } else {
2411 SE_ASSIGN_OR_RETURN(a_temporary,
2412 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2413 SE_ASSIGN_OR_RETURN(b_temporary,
2414 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2415 SE_ASSIGN_OR_RETURN(c_temporary,
2416 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
2417 a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
2418 b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
2419 c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
2420 }
2421
2422 if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
2423 !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
2424 !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
2425 return port::Status(port::error::INTERNAL,
2426 "failed to copy memory from host to device in "
2427 "CUDABlas::DoBlasGemmBatched");
2428 }
2429
2430 cudaDataType_t data_type = CUDADataType<T>::type;
2431
2432 #if CUDA_VERSION >= 9010
2433 int cc_major, cc_minor;
2434 if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
2435 &cc_major, &cc_minor) &&
2436 cc_major >= 5) {
2437 cublasMath_t math_type;
2438 cublasGemmAlgo_t algo;
2439 if (data_type == CUDA_R_16F) {
2440 #if CUDA_VERSION < 11000
2441 math_type = CUBLAS_TENSOR_OP_MATH;
2442 #else
2443 math_type = CUBLAS_DEFAULT_MATH;
2444 #endif
2445 algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
2446 #if CUBLAS_VER_MAJOR >= 11
2447 } else if (data_type == CUDA_R_32F) {
2448       // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
2449 // if TensorFloat-32 is disabled.
2450 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2451 algo = tensorflow::tensor_float_32_execution_enabled()
2452 ? CUBLAS_GEMM_DFALT_TENSOR_OP
2453 : CUBLAS_GEMM_DFALT;
2454 #endif
2455 } else {
2456 math_type = CUBLAS_DEFAULT_MATH;
2457 algo = CUBLAS_GEMM_DFALT;
2458 }
2459 cudaDataType_t compute_type =
2460 (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
2461 const void **a_void_ptrs = reinterpret_cast<const void **>(
2462 const_cast<const CUDA_T **>(GpuMemory(a)));
2463 const void **b_void_ptrs = reinterpret_cast<const void **>(
2464 const_cast<const CUDA_T **>(GpuMemory(b)));
2465 void **c_void_ptrs =
2466 reinterpret_cast<void **>(const_cast<CUDA_T **>(GpuMemory(c)));
2467 bool ok;
2468 ok = DoBlasInternalImpl(
2469 AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
2470 true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
2471 CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
2472 b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
2473 batch_count, compute_type, algo);
2474 if (ok) {
2475 return port::Status::OK();
2476 }
2477 return port::Status(port::error::INTERNAL,
2478 "failed BLAS call, see log for details");
2479 }
2480 #endif
2481   // Either CUDA_VERSION < 9.1 or SM < 5.0: fall back to the non-Ex code paths.
2482 if (data_type != CUDA_R_16F) {
2483 auto cb_alpha = GpuComplexValue(alpha);
2484 auto cb_beta = GpuComplexValue(beta);
2485 bool ok = DoBlasInternal(
2486 cublas_func, stream, true /* = pointer_mode_host */,
2487 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2488 GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
2489 const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2490 const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
2491 if (ok) {
2492 return port::Status::OK();
2493 }
2494 return port::Status(port::error::INTERNAL,
2495 "failed BLAS call, see log for details");
2496 } else {
2497 // Fall back to a loop for fp16
2498 for (int b = 0; b < batch_count; ++b) {
2499 const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
2500 const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
2501 DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
2502 bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix,
2503 lda, b_matrix, ldb, beta, c_matrix, ldc);
2504 if (!ok) {
2505 return port::Status(port::error::INTERNAL,
2506 "failed BLAS call, see log for details");
2507 }
2508 }
2509 return port::Status::OK();
2510 }
2511 }
2512
2513 bool CUDABlas::DoBlasGemmBatched(
2514 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2515 uint64 n, uint64 k, float alpha,
2516 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
2517 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
2518 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
2519 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2520 // Note: The func passed here (cublasSgemmBatched) is not actually called,
2521 // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
2522 port::Status status = DoBlasGemmBatchedInternal(
2523 cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2524 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2525 if (!status.ok()) {
2526 LOG(ERROR) << status;
2527 }
2528 return status.ok();
2529 }
2530
2531 bool CUDABlas::DoBlasGemmBatched(
2532 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2533 uint64 n, uint64 k, float alpha,
2534 const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
2535 const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
2536 const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
2537 int batch_count, ScratchAllocator *scratch_allocator) {
2538 port::Status status = DoBlasGemmBatchedInternal(
2539 cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2540 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2541 if (!status.ok()) {
2542 LOG(ERROR) << status;
2543 }
2544 return status.ok();
2545 }
2546
2547 bool CUDABlas::DoBlasGemmBatched(
2548 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2549 uint64 n, uint64 k, double alpha,
2550 const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
2551 const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
2552 double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
2553 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2554 port::Status status = DoBlasGemmBatchedInternal(
2555 cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2556 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2557 if (!status.ok()) {
2558 LOG(ERROR) << status;
2559 }
2560 return status.ok();
2561 }
2562
2563 bool CUDABlas::DoBlasGemmBatched(
2564 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2565 uint64 n, uint64 k, std::complex<float> alpha,
2566 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
2567 int lda,
2568 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
2569 int ldb, std::complex<float> beta,
2570 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
2571 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2572 port::Status status = DoBlasGemmBatchedInternal(
2573 cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2574 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2575 if (!status.ok()) {
2576 LOG(ERROR) << status;
2577 }
2578 return status.ok();
2579 }
2580
2581 bool CUDABlas::DoBlasGemmBatched(
2582 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2583 uint64 n, uint64 k, std::complex<double> alpha,
2584 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
2585 int lda,
2586 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
2587 int ldb, std::complex<double> beta,
2588 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
2589 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2590 port::Status status = DoBlasGemmBatchedInternal(
2591 cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
2592 b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
2593 if (!status.ok()) {
2594 LOG(ERROR) << status;
2595 }
2596 return status.ok();
2597 }
2598
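// Strided-batched fp16 GEMM. On CUDA >= 9.1 with SM >= 5.0 this uses
// cublasGemmStridedBatchedEx with fp16 storage and fp32 accumulation (tensor
// ops on SM >= 7.0); otherwise it falls back to one cublasSgemmEx call per
// batch, offsetting the base pointers by the given strides.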
2599 bool CUDABlas::DoBlasGemmStridedBatched(
2600 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2601 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
2602 int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
2603 int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
2604 int64 stride_c, int batch_count) {
2605 #if CUDA_VERSION >= 9010
2606 int cc_major, cc_minor;
2607 if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
2608 &cc_major, &cc_minor) &&
2609 cc_major >= 5) {
2610 cublasGemmAlgo_t algo =
2611 (cc_major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
2612 #if CUDA_VERSION < 11000
2613 cublasMath_t math_type = CUBLAS_TENSOR_OP_MATH;
2614 #else
2615 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2616 #endif
2617 bool ok = DoBlasInternalImpl(
2618 AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
2619 true /* = pointer_mode_host */, true /* = err_on_failure */, math_type,
2620 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
2621 GpuMemory(a), CUDA_R_16F, lda, stride_a, GpuMemory(b), CUDA_R_16F, ldb,
2622 stride_b, &beta, GpuMemoryMutable(c), CUDA_R_16F, ldc, stride_c,
2623 batch_count, CUDA_R_32F, algo);
2624 if (ok) {
2625 return true;
2626 }
2627 LOG(ERROR) << "failed BLAS call, see log for details";
2628 return false;
2629 }
2630 #endif
2631 // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
2632 for (int batch = 0; batch < batch_count; ++batch) {
2633 const auto *a_matrix =
2634 reinterpret_cast<const __half *>(GpuMemory(a) + batch * stride_a);
2635 const auto *b_matrix =
2636 reinterpret_cast<const __half *>(GpuMemory(b) + batch * stride_b);
2637 auto *c_matrix =
2638 reinterpret_cast<__half *>(GpuMemoryMutable(c) + batch * stride_c);
2639 bool ok = DoBlasInternalImpl(
2640 cublasSgemmEx, stream, true /* = pointer_mode_host */,
2641         true /* = err_on_failure */, CUBLAS_DEFAULT_MATH,
2642 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
2643 a_matrix, SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb,
2644 &beta, c_matrix, SE_CUDA_DATA_HALF, ldc);
2645 if (!ok) {
2646 LOG(ERROR) << "failed BLAS call, see log for details";
2647 return false;
2648 }
2649 }
2650 return true;
2651 }
2652
2653 bool CUDABlas::DoBlasGemmStridedBatched(
2654 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2655 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
2656 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
2657 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
2658 int batch_count) {
2659 #if CUDA_VERSION < 11000
2660 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
2661 #else
2662 cublasMath_t math_type = CUBLAS_TF32_TENSOR_OP_MATH;
2663 #endif
2664 return DoBlasInternalImpl(
2665 cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
2666 true /* = err_on_failure */, math_type, CUDABlasTranspose(transa),
2667 CUDABlasTranspose(transb), m, n, k, &alpha, GpuMemory(a), lda, stride_a,
2668 GpuMemory(b), ldb, stride_b, &beta, GpuMemoryMutable(c), ldc, stride_c,
2669 batch_count);
2670 }
2671
2672 bool CUDABlas::DoBlasGemmStridedBatched(
2673 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2674 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
2675 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
2676 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
2677 int batch_count) {
2678 return DoBlasInternal(
2679 cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
2680 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
2681 GpuMemory(a), lda, stride_a, GpuMemory(b), ldb, stride_b, &beta,
2682 GpuMemoryMutable(c), ldc, stride_c, batch_count);
2683 }
2684
2685 bool CUDABlas::DoBlasGemmStridedBatched(
2686 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2687 uint64 n, uint64 k, std::complex<float> alpha,
2688 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
2689 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
2690 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
2691 int64 stride_c, int batch_count) {
2692 auto cb_alpha = GpuComplexValue(alpha);
2693 auto cb_beta = GpuComplexValue(beta);
2694 return DoBlasInternal(
2695 cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
2696 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2697 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
2698 GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
2699 GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
2700 }
2701
2702 bool CUDABlas::DoBlasGemmStridedBatched(
2703 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2704 uint64 n, uint64 k, std::complex<double> alpha,
2705 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
2706 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
2707 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
2708 int64 stride_c, int batch_count) {
2709 auto cb_alpha = GpuComplexValue(alpha);
2710 auto cb_beta = GpuComplexValue(beta);
2711 return DoBlasInternal(
2712 cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
2713 CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
2714 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda, stride_a,
2715 GpuComplex(GpuMemory(b)), ldb, stride_b, GpuComplex(&cb_beta),
2716 GpuComplex(GpuMemoryMutable(c)), ldc, stride_c, batch_count);
2717 }
2718
2719 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
2720 blas::UpperLower uplo, uint64 m, uint64 n,
2721 std::complex<float> alpha,
2722 const DeviceMemory<std::complex<float>> &a, int lda,
2723 const DeviceMemory<std::complex<float>> &b, int ldb,
2724 std::complex<float> beta,
2725 DeviceMemory<std::complex<float>> *c, int ldc) {
2726 auto cb_alpha = GpuComplexValue(alpha);
2727 auto cb_beta = GpuComplexValue(beta);
2728 return DoBlasInternal(cublasChemm, stream, true /* = pointer_mode_host */,
2729 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2730 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2731 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2732 GpuComplex(GpuMemoryMutable(c)), ldc);
2733 }
2734
2735 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
2736 blas::UpperLower uplo, uint64 m, uint64 n,
2737 std::complex<double> alpha,
2738 const DeviceMemory<std::complex<double>> &a, int lda,
2739 const DeviceMemory<std::complex<double>> &b, int ldb,
2740 std::complex<double> beta,
2741 DeviceMemory<std::complex<double>> *c, int ldc) {
2742 auto cb_alpha = GpuComplexValue(alpha);
2743 auto cb_beta = GpuComplexValue(beta);
2744 return DoBlasInternal(cublasZhemm, stream, true /* = pointer_mode_host */,
2745 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2746 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2747 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2748 GpuComplex(GpuMemoryMutable(c)), ldc);
2749 }
2750
2751 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2752 blas::Transpose trans, uint64 n, uint64 k,
2753 float alpha,
2754 const DeviceMemory<std::complex<float>> &a, int lda,
2755 float beta, DeviceMemory<std::complex<float>> *c,
2756 int ldc) {
2757 return DoBlasInternal(cublasCherk, stream, true /* = pointer_mode_host */,
2758 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2759 k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
2760 GpuComplex(GpuMemoryMutable(c)), ldc);
2761 }
2762
2763 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2764 blas::Transpose trans, uint64 n, uint64 k,
2765 double alpha,
2766 const DeviceMemory<std::complex<double>> &a, int lda,
2767 double beta, DeviceMemory<std::complex<double>> *c,
2768 int ldc) {
2769 return DoBlasInternal(cublasZherk, stream, true /* = pointer_mode_host */,
2770 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2771 k, &alpha, GpuComplex(GpuMemory(a)), lda, &beta,
2772 GpuComplex(GpuMemoryMutable(c)), ldc);
2773 }
2774
2775 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2776 blas::Transpose trans, uint64 n, uint64 k,
2777 std::complex<float> alpha,
2778 const DeviceMemory<std::complex<float>> &a, int lda,
2779 const DeviceMemory<std::complex<float>> &b, int ldb,
2780 float beta, DeviceMemory<std::complex<float>> *c,
2781 int ldc) {
2782 auto cb_alpha = GpuComplexValue(alpha);
2783 return DoBlasInternal(cublasCher2k, stream, true /* = pointer_mode_host */,
2784 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2785 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2786 GpuComplex(GpuMemory(b)), ldb, &beta,
2787 GpuComplex(GpuMemoryMutable(c)), ldc);
2788 }
2789
2790 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2791 blas::Transpose trans, uint64 n, uint64 k,
2792 std::complex<double> alpha,
2793 const DeviceMemory<std::complex<double>> &a, int lda,
2794 const DeviceMemory<std::complex<double>> &b, int ldb,
2795 double beta, DeviceMemory<std::complex<double>> *c,
2796 int ldc) {
2797 auto cb_alpha = GpuComplexValue(alpha);
2798 return DoBlasInternal(cublasZher2k, stream, true /* = pointer_mode_host */,
2799 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2800 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2801 GpuComplex(GpuMemory(b)), ldb, &beta,
2802 GpuComplex(GpuMemoryMutable(c)), ldc);
2803 }
2804
2805 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2806 blas::UpperLower uplo, uint64 m, uint64 n,
2807 float alpha, const DeviceMemory<float> &a, int lda,
2808 const DeviceMemory<float> &b, int ldb, float beta,
2809 DeviceMemory<float> *c, int ldc) {
2810 return DoBlasInternal(cublasSsymm, stream, true /* = pointer_mode_host */,
2811 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2812 &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2813 GpuMemoryMutable(c), ldc);
2814 }
2815
2816 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2817 blas::UpperLower uplo, uint64 m, uint64 n,
2818 double alpha, const DeviceMemory<double> &a, int lda,
2819 const DeviceMemory<double> &b, int ldb, double beta,
2820 DeviceMemory<double> *c, int ldc) {
2821 return DoBlasInternal(cublasDsymm, stream, true /* = pointer_mode_host */,
2822 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2823 &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2824 GpuMemoryMutable(c), ldc);
2825 }
2826
2827 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2828 blas::UpperLower uplo, uint64 m, uint64 n,
2829 std::complex<float> alpha,
2830 const DeviceMemory<std::complex<float>> &a, int lda,
2831 const DeviceMemory<std::complex<float>> &b, int ldb,
2832 std::complex<float> beta,
2833 DeviceMemory<std::complex<float>> *c, int ldc) {
2834 auto cb_alpha = GpuComplexValue(alpha);
2835 auto cb_beta = GpuComplexValue(beta);
2836 return DoBlasInternal(cublasCsymm, stream, true /* = pointer_mode_host */,
2837 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2838 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2839 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2840 GpuComplex(GpuMemoryMutable(c)), ldc);
2841 }
2842
2843 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
2844 blas::UpperLower uplo, uint64 m, uint64 n,
2845 std::complex<double> alpha,
2846 const DeviceMemory<std::complex<double>> &a, int lda,
2847 const DeviceMemory<std::complex<double>> &b, int ldb,
2848 std::complex<double> beta,
2849 DeviceMemory<std::complex<double>> *c, int ldc) {
2850 auto cb_alpha = GpuComplexValue(alpha);
2851 auto cb_beta = GpuComplexValue(beta);
2852 return DoBlasInternal(cublasZsymm, stream, true /* = pointer_mode_host */,
2853 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n,
2854 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2855 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2856 GpuComplex(GpuMemoryMutable(c)), ldc);
2857 }
2858
2859 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2860 blas::Transpose trans, uint64 n, uint64 k,
2861 float alpha, const DeviceMemory<float> &a, int lda,
2862 float beta, DeviceMemory<float> *c, int ldc) {
2863 return DoBlasInternal(cublasSsyrk, stream, true /* = pointer_mode_host */,
2864 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2865 k, &alpha, GpuMemory(a), lda, &beta,
2866 GpuMemoryMutable(c), ldc);
2867 }
2868
2869 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2870 blas::Transpose trans, uint64 n, uint64 k,
2871 double alpha, const DeviceMemory<double> &a, int lda,
2872 double beta, DeviceMemory<double> *c, int ldc) {
2873 return DoBlasInternal(cublasDsyrk, stream, true /* = pointer_mode_host */,
2874 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2875 k, &alpha, GpuMemory(a), lda, &beta,
2876 GpuMemoryMutable(c), ldc);
2877 }
2878
2879 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2880 blas::Transpose trans, uint64 n, uint64 k,
2881 std::complex<float> alpha,
2882 const DeviceMemory<std::complex<float>> &a, int lda,
2883 std::complex<float> beta,
2884 DeviceMemory<std::complex<float>> *c, int ldc) {
2885 auto cb_alpha = GpuComplexValue(alpha);
2886 auto cb_beta = GpuComplexValue(beta);
2887 return DoBlasInternal(cublasCsyrk, stream, true /* = pointer_mode_host */,
2888 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2889 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2890 GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
2891 ldc);
2892 }
2893
2894 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2895 blas::Transpose trans, uint64 n, uint64 k,
2896 std::complex<double> alpha,
2897 const DeviceMemory<std::complex<double>> &a, int lda,
2898 std::complex<double> beta,
2899 DeviceMemory<std::complex<double>> *c, int ldc) {
2900 auto cb_alpha = GpuComplexValue(alpha);
2901 auto cb_beta = GpuComplexValue(beta);
2902 return DoBlasInternal(cublasZsyrk, stream, true /* = pointer_mode_host */,
2903 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2904 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2905 GpuComplex(&cb_beta), GpuComplex(GpuMemoryMutable(c)),
2906 ldc);
2907 }
2908
2909 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2910 blas::Transpose trans, uint64 n, uint64 k,
2911 float alpha, const DeviceMemory<float> &a, int lda,
2912 const DeviceMemory<float> &b, int ldb, float beta,
2913 DeviceMemory<float> *c, int ldc) {
2914 return DoBlasInternal(cublasSsyr2k, stream, true /* = pointer_mode_host */,
2915 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2916 k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2917 GpuMemoryMutable(c), ldc);
2918 }
2919
2920 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2921 blas::Transpose trans, uint64 n, uint64 k,
2922 double alpha, const DeviceMemory<double> &a, int lda,
2923 const DeviceMemory<double> &b, int ldb, double beta,
2924 DeviceMemory<double> *c, int ldc) {
2925 return DoBlasInternal(cublasDsyr2k, stream, true /* = pointer_mode_host */,
2926 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2927 k, &alpha, GpuMemory(a), lda, GpuMemory(b), ldb, &beta,
2928 GpuMemoryMutable(c), ldc);
2929 }
2930
2931 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2932 blas::Transpose trans, uint64 n, uint64 k,
2933 std::complex<float> alpha,
2934 const DeviceMemory<std::complex<float>> &a, int lda,
2935 const DeviceMemory<std::complex<float>> &b, int ldb,
2936 std::complex<float> beta,
2937 DeviceMemory<std::complex<float>> *c, int ldc) {
2938 auto cb_alpha = GpuComplexValue(alpha);
2939 auto cb_beta = GpuComplexValue(beta);
2940 return DoBlasInternal(cublasCsyr2k, stream, true /* = pointer_mode_host */,
2941 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2942 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2943 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2944 GpuComplex(GpuMemoryMutable(c)), ldc);
2945 }
2946
2947 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2948 blas::Transpose trans, uint64 n, uint64 k,
2949 std::complex<double> alpha,
2950 const DeviceMemory<std::complex<double>> &a, int lda,
2951 const DeviceMemory<std::complex<double>> &b, int ldb,
2952 std::complex<double> beta,
2953 DeviceMemory<std::complex<double>> *c, int ldc) {
2954 auto cb_alpha = GpuComplexValue(alpha);
2955 auto cb_beta = GpuComplexValue(beta);
2956 return DoBlasInternal(cublasZsyr2k, stream, true /* = pointer_mode_host */,
2957 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
2958 k, GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2959 GpuComplex(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
2960 GpuComplex(GpuMemoryMutable(c)), ldc);
2961 }
2962
2963 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2964 blas::UpperLower uplo, blas::Transpose transa,
2965 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2966 const DeviceMemory<float> &a, int lda,
2967 DeviceMemory<float> *b, int ldb) {
2968 return DoBlasInternal(cublasStrmm, stream, true /* = pointer_mode_host */,
2969 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2970 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2971 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
2972 GpuMemoryMutable(b), ldb);
2973 }
2974
2975 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2976 blas::UpperLower uplo, blas::Transpose transa,
2977 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2978 const DeviceMemory<double> &a, int lda,
2979 DeviceMemory<double> *b, int ldb) {
2980 return DoBlasInternal(cublasDtrmm, stream, true /* = pointer_mode_host */,
2981 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2982 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2983 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb,
2984 GpuMemoryMutable(b), ldb);
2985 }
2986
2987 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
2988 blas::UpperLower uplo, blas::Transpose transa,
2989 blas::Diagonal diag, uint64 m, uint64 n,
2990 std::complex<float> alpha,
2991 const DeviceMemory<std::complex<float>> &a, int lda,
2992 DeviceMemory<std::complex<float>> *b, int ldb) {
2993 auto cb_alpha = GpuComplexValue(alpha);
2994 return DoBlasInternal(cublasCtrmm, stream, true /* = pointer_mode_host */,
2995 CUDABlasSide(side), CUDABlasUpperLower(uplo),
2996 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
2997 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
2998 GpuComplex(GpuMemoryMutable(b)), ldb,
2999 GpuComplex(GpuMemoryMutable(b)), ldb);
3000 }
3001
3002 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
3003 blas::UpperLower uplo, blas::Transpose transa,
3004 blas::Diagonal diag, uint64 m, uint64 n,
3005 std::complex<double> alpha,
3006 const DeviceMemory<std::complex<double>> &a, int lda,
3007 DeviceMemory<std::complex<double>> *b, int ldb) {
3008 auto cb_alpha = GpuComplexValue(alpha);
3009 return DoBlasInternal(cublasZtrmm, stream, true /* = pointer_mode_host */,
3010 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3011 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3012 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3013 GpuComplex(GpuMemoryMutable(b)), ldb,
3014 GpuComplex(GpuMemoryMutable(b)), ldb);
3015 }
3016
3017 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
3018 blas::UpperLower uplo, blas::Transpose transa,
3019 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
3020 const DeviceMemory<float> &a, int lda,
3021 DeviceMemory<float> *b, int ldb) {
3022 return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
3023 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3024 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3025 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
3026 }
3027
3028 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
3029 blas::UpperLower uplo, blas::Transpose transa,
3030 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
3031 const DeviceMemory<double> &a, int lda,
3032 DeviceMemory<double> *b, int ldb) {
3033 return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
3034 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3035 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3036 &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
3037 }
3038
3039 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
3040 blas::UpperLower uplo, blas::Transpose transa,
3041 blas::Diagonal diag, uint64 m, uint64 n,
3042 std::complex<float> alpha,
3043 const DeviceMemory<std::complex<float>> &a, int lda,
3044 DeviceMemory<std::complex<float>> *b, int ldb) {
3045 auto cb_alpha = GpuComplexValue(alpha);
3046 return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
3047 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3048 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3049 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3050 GpuComplex(GpuMemoryMutable(b)), ldb);
3051 }
3052
3053 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
3054 blas::UpperLower uplo, blas::Transpose transa,
3055 blas::Diagonal diag, uint64 m, uint64 n,
3056 std::complex<double> alpha,
3057 const DeviceMemory<std::complex<double>> &a, int lda,
3058 DeviceMemory<std::complex<double>> *b, int ldb) {
3059 auto cb_alpha = GpuComplexValue(alpha);
3060 return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
3061 CUDABlasSide(side), CUDABlasUpperLower(uplo),
3062 CUDABlasTranspose(transa), CUDABlasDiagonal(diag), m, n,
3063 GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
3064 GpuComplex(GpuMemoryMutable(b)), ldb);
3065 }
3066
3067 // We only use cublasLt from CUDA 11.0 onward.
3068 #if CUDA_VERSION >= 11000
3069
3070 namespace {
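// The helpers below wrap cublasLt descriptor creation and attribute setting so
// that the plan/algorithm classes further down can surface errors as
// port::Status instead of raw cublasStatus_t codes.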
3071
3072 template <typename T>
3073 inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
3074 cublasLtMatrixLayoutAttribute_t attr,
3075 const T &value) {
3076 cublasStatus_t status =
3077 cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
3078 if (status != CUBLAS_STATUS_SUCCESS) {
3079 return port::Status(
3080 port::error::INTERNAL,
3081 absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
3082 ", value=", value, ") failed: ", ToString(status)));
3083 }
3084 return port::Status::OK();
3085 }
3086
3087 template <typename T>
3088 inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle,
3089 cublasLtMatmulAlgoConfigAttributes_t attr,
3090 const T &value) {
3091 cublasStatus_t status =
3092 cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
3093 if (status != CUBLAS_STATUS_SUCCESS) {
3094 return port::Status(
3095 port::error::INTERNAL,
3096 absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
3097 ", value=", value, ") failed: ", ToString(status)));
3098 }
3099 return port::Status::OK();
3100 }
3101
3102 template <typename T>
3103 inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
3104 cublasLtMatmulPreferenceAttributes_t attr,
3105 const T &value) {
3106 cublasStatus_t status =
3107 cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
3108 if (status != CUBLAS_STATUS_SUCCESS) {
3109 return port::Status(
3110 port::error::INTERNAL,
3111 absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
3112 ", value=", value, ") failed: ", ToString(status)));
3113 }
3114 return port::Status::OK();
3115 }
3116
3117 template <typename T>
3118 inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle,
3119 cublasLtMatmulAlgoConfigAttributes_t attr,
3120 T *value) {
3121 auto mutable_handle = const_cast<cublasLtMatmulAlgo_t *>(handle);
3122 size_t bytes_written = 0;
3123 return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
3124 sizeof(T), &bytes_written) ==
3125 CUBLAS_STATUS_SUCCESS &&
3126 bytes_written == sizeof(T);
3127 }
3128
3129 template <typename T>
3130 inline const T &ValueForStrCat(const T &value) {
3131 return value;
3132 }
3133 template <typename T>
3134 inline absl::Hex ValueForStrCat(T *ptr) {
3135 return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
3136 }
3137
3138 template <typename T>
3139 inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
3140 cublasLtMatmulDescAttributes_t attr,
3141 const T &value) {
3142 cublasStatus_t status =
3143 cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
3144 if (status != CUBLAS_STATUS_SUCCESS) {
3145 return port::Status(
3146 port::error::INTERNAL,
3147 absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
3148 ValueForStrCat(value), ") failed: ", ToString(status)));
3149 }
3150 return port::Status::OK();
3151 }
3152
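// RAII deleters and unique_ptr aliases: these ensure that cublasLt descriptor
// handles created below are destroyed on every return path, including early
// error returns from SE_RETURN_IF_ERROR.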
3153 struct MatmulDescDestroyer {
3154 void operator()(cublasLtMatmulDesc_t matmul_desc) const {
3155 cublasLtMatmulDescDestroy(matmul_desc);
3156 }
3157 };
3158 struct LayoutDestroyer {
3159 void operator()(cublasLtMatrixLayout_t layout) const {
3160 cublasLtMatrixLayoutDestroy(layout);
3161 }
3162 };
3163 struct MatmulPreferenceDestroyer {
3164 void operator()(cublasLtMatmulPreference_t matmul_pref) const {
3165 cublasLtMatmulPreferenceDestroy(matmul_pref);
3166 }
3167 };
3168 using UniqueOpDesc =
3169 std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
3170 MatmulDescDestroyer>;
3171 using UniqueLayoutDesc =
3172 std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
3173 LayoutDestroyer>;
3174 using UniqueMatmulPreference =
3175 std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
3176 MatmulPreferenceDestroyer>;
3177
3178 port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
3179 blas::ComputationType computation_type, blas::DataType scale_type,
3180 blas::PointerMode pointer_mode, blas::Epilogue epilogue,
3181 blas::Transpose transa, blas::Transpose transb) {
3182 cublasLtMatmulDesc_t desc;
3183 cublasComputeType_t cublas_compute_type =
3184 CUBLASComputationType(computation_type);
3185 cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
3186 cublasStatus_t status =
3187 cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
3188 if (status != CUBLAS_STATUS_SUCCESS) {
3189 return port::Status(
3190 port::error::INTERNAL,
3191 absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
3192 computation_type, ") failed: ", ToString(status)));
3193 }
3194 UniqueOpDesc unique_desc(desc);
3195 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
3196 CUBLASPointerMode(pointer_mode)));
3197 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
3198 CUBLASEpilogue(epilogue)));
3199 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
3200 CUDABlasTranspose(transa)));
3201 SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
3202 CUDABlasTranspose(transb)));
3203 return unique_desc;
3204 }
3205
3206 port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
3207 blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride,
3208 int batch_count) {
3209 cublasLtMatrixLayout_t desc;
3210 cublasStatus_t status = cublasLtMatrixLayoutCreate(
3211 &desc, GetCUDADataType(data_type), rows, cols, ld);
3212 if (status != CUBLAS_STATUS_SUCCESS) {
3213 return port::Status(
3214 port::error::INTERNAL,
3215 absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
3216 }
3217 UniqueLayoutDesc unique_desc(desc);
3218 SE_RETURN_IF_ERROR(
3219 SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
3220 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3221 desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
3222 return unique_desc;
3223 }
3224
3225 // Helper function to allocate workspace.
3226 port::Status AllocateWorkspace(void **workspace,
3227 ScratchAllocator *scratch_allocator,
3228 size_t num_bytes) {
3229 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
3230 scratch_allocator->AllocateBytes(num_bytes));
3231 *workspace = (void *)GpuMemoryMutable(&workspace_bytes);
3232 return port::Status::OK();
3233 }
3234
3235 template <typename T>
3236 blas::ComputationType ToComputationType();
3237 template <>
3238 blas::ComputationType ToComputationType<Eigen::half>() {
3239 return blas::ComputationType::kF16;
3240 }
3241 template <>
3242 blas::ComputationType ToComputationType<float>() {
3243 return blas::ComputationType::kF32;
3244 }
3245 template <>
3246 blas::ComputationType ToComputationType<double>() {
3247 return blas::ComputationType::kF64;
3248 }
3249 template <>
3250 blas::ComputationType ToComputationType<std::complex<float>>() {
3251 return blas::ComputationType::kComplexF32;
3252 }
3253 template <>
3254 blas::ComputationType ToComputationType<std::complex<double>>() {
3255 return blas::ComputationType::kComplexF64;
3256 }
3257
3258 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
3259 public:
3260 port::Status init(const blas::BlasLtMatmulPlanParams &p) {
3261 params_ = p;
3262 scale_type_ = GetScaleType(p.c_type, p.computation_type);
3263 SE_ASSIGN_OR_RETURN(
3264 op_desc_,
3265 CreateCublasLtOperationDesc(
3266 p.computation_type, GetScaleType(p.c_type, p.computation_type),
3267 p.pointer_mode, p.epilogue, p.transa, p.transb));
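    // cublasLt layout descriptors describe each matrix as stored in memory,
    // so the row/column counts are swapped whenever the corresponding operand
    // is transposed.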
3268 uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
3269 uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
3270 uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
3271 uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
3272 SE_ASSIGN_OR_RETURN(
3273 a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
3274 p.stride_a, capped_batch_count()));
3275 SE_ASSIGN_OR_RETURN(
3276 b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
3277 p.stride_b, capped_batch_count()));
3278 SE_ASSIGN_OR_RETURN(
3279 c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3280 capped_batch_count()));
3281 SE_ASSIGN_OR_RETURN(
3282 d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3283 capped_batch_count()));
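    // cublasLt limits how many batches a single call may process (see
    // kMaxBatchCount below). If the requested batch_count exceeds the cap,
    // separate layouts are created for the final, smaller chunk of batches.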
3284 remainder_batch_count_ =
3285 p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
3286 if (remainder_batch_count_) {
3287 SE_ASSIGN_OR_RETURN(
3288 a_remainder_desc_,
3289 CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
3290 remainder_batch_count_));
3291 SE_ASSIGN_OR_RETURN(
3292 b_remainder_desc_,
3293 CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
3294 remainder_batch_count_));
3295 SE_ASSIGN_OR_RETURN(
3296 c_remainder_desc_,
3297 CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3298 remainder_batch_count_));
3299 SE_ASSIGN_OR_RETURN(
3300 d_remainder_desc_,
3301 CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
3302 remainder_batch_count_));
3303 }
3304 return port::Status::OK();
3305 }
3306
3307 cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
3308 cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
3309 cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
3310 cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
3311 cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
3312 cublasLtMatrixLayout_t a_remainder_desc() const {
3313 return a_remainder_desc_.get();
3314 }
3315 cublasLtMatrixLayout_t b_remainder_desc() const {
3316 return b_remainder_desc_.get();
3317 }
3318 cublasLtMatrixLayout_t c_remainder_desc() const {
3319 return c_remainder_desc_.get();
3320 }
3321 cublasLtMatrixLayout_t d_remainder_desc() const {
3322 return d_remainder_desc_.get();
3323 }
3324
3325 const blas::BlasLtMatmulPlanParams &params() const { return params_; }
3326 blas::DataType scale_type() const { return scale_type_; }
3327 blas::DataType ab_type() const override { return params_.ab_type; }
3328 blas::DataType c_type() const override { return params_.c_type; }
3329 int capped_batch_count() const {
3330 return std::min(params_.batch_count, kMaxBatchCount);
3331 }
3332 int remainder_batch_count() const { return remainder_batch_count_; }
3333
3334 // Note: Must be const to satisfy API. This is always called before the plan
3335 // is executed, so the state change is not observed in subsequent executions.
3336 bool SetBiasPointer(const void *bias) const;
3337
3338 private:
3339 // In some cases cublasLt does not support large batch sizes, so we need to
3340 // split up such cases into multiple calls.
3341 // TODO(reedwm): Making this static or constexpr causes a link error with gcc
3342 // in debug mode for unknown reasons. Investigate why.
3343 const int kMaxBatchCount = 65535;
3344 blas::BlasLtMatmulPlanParams params_;
3345 blas::DataType scale_type_;
3346 UniqueOpDesc op_desc_;
3347 // These have batch count set to capped_batch_count().
3348 UniqueLayoutDesc a_desc_;
3349 UniqueLayoutDesc b_desc_;
3350 UniqueLayoutDesc c_desc_;
3351 UniqueLayoutDesc d_desc_;
3352 int remainder_batch_count_;
3353 // These have batch count set to remainder_batch_count_, and are only created
3354 // if params_.batch_count > kMaxBatchCount.
3355 UniqueLayoutDesc a_remainder_desc_;
3356 UniqueLayoutDesc b_remainder_desc_;
3357 UniqueLayoutDesc c_remainder_desc_;
3358 UniqueLayoutDesc d_remainder_desc_;
3359 };
3360
3361 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
3362 return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
3363 bias)
3364 .ok();
3365 }
3366
3367 class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
3368 public:
3369 CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
3370 cublasLtMatmulAlgo_t algo, size_t workspace_size)
3371 : index_(index), algo_(algo), workspace_size_(workspace_size) {}
3372
3373 blas::AlgorithmType index() const override { return index_; }
3374
3375 size_t workspace_size() const override { return workspace_size_; }
3376
3377 const cublasLtMatmulAlgo_t *algo() const { return &algo_; }
3378
3379 int algo_id() const {
3380 int id;
3381 GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
3382 return id;
3383 }
3384
3385 private:
3386 blas::AlgorithmType index_;
3387 cublasLtMatmulAlgo_t algo_;
3388 size_t workspace_size_;
3389 };
3390
3391 port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
3392 const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) {
3393 cublasLtMatmulPreference_t preference;
3394 cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
3395 if (status != CUBLAS_STATUS_SUCCESS) {
3396 return port::Status(port::error::INTERNAL,
3397 absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
3398 ToString(status)));
3399 }
3400 UniqueMatmulPreference unique_preference(preference);
3401 SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
3402 CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
3403 max_workspace_bytes));
3404
3405 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3406 if (cuda_plan.params().batch_count == 0) {
3407 return unique_preference;
3408 }
3409 // This is a workaround for a known issue in cuBlasLt where the heuristic may
3410 // in rare cases select an algo that does not support the specified stride.
3411 // Specifying the alignment requirements manually like this avoids the issue.
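  // (stride & -stride) isolates the lowest set bit of the stride, i.e. the
  // largest power of two dividing it, so the reported alignment is that
  // divisor times the element size in bytes.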
3412 auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
3413 return (stride & -stride) * GetDataTypeSizeBytes(dtype);
3414 };
3415 if (cuda_plan.params().stride_a) {
3416 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3417 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
3418 (uint32)get_alignment_bytes(cuda_plan.params().stride_a,
3419 cuda_plan.params().ab_type)));
3420 }
3421 if (cuda_plan.params().stride_b) {
3422 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3423 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
3424 (uint32)get_alignment_bytes(cuda_plan.params().stride_b,
3425 cuda_plan.params().ab_type)));
3426 }
3427 if (cuda_plan.params().stride_c) {
3428 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3429 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
3430 (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
3431 cuda_plan.params().c_type)));
3432 }
3433 if (cuda_plan.params().stride_c) {
3434 SE_RETURN_IF_ERROR(SetCublasLtAttr(
3435 preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
3436 (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
3437 cuda_plan.params().c_type)));
3438 }
3439 return unique_preference;
3440 }
3441
3442 } // namespace
3443
3444 #endif // CUDA_VERSION >= 11000
3445
3446 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
3447 CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
3448 #if CUDA_VERSION >= 11000
3449 auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
3450 SE_RETURN_IF_ERROR(cuda_plan->init(p));
3451 return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
3452 std::move(cuda_plan));
3453 #else
3454 return port::Status(
3455 port::error::UNIMPLEMENTED,
3456 "CreateBlasLtMatmulPlan is not supported with this version of CUDA");
3457 #endif
3458 }
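
// Illustrative caller sketch (not an exhaustive contract): a client typically
// creates a plan with CreateBlasLtMatmulPlan(params), queries candidate
// algorithms with GetBlasLtMatmulAlgorithms(plan.get(), max_workspace, count),
// and then invokes DoBlasLtMatmul(stream, plan.get(), alpha, a, b, beta, c,
// scratch_allocator, algorithms[i].get()) with one of the returned algorithms.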
3459
3460 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
3461 CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
3462 size_t max_workspace_size,
3463 int max_algorithm_count,
3464 bool for_remainder_batch) {
3465 #if CUDA_VERSION >= 11000
3466 SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
3467 CreateCublasLtMatmulPreference(plan, max_workspace_size));
3468
3469 std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
3470 {
3471 absl::MutexLock lock(&mu_);
3472
3473 CHECK(blasLt_ != nullptr);
3474
3475 gpu::ScopedActivateExecutorContext sac{parent_};
3476
3477 int found_algorithm_count = 0;
3478 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3479 const auto &a_desc =
3480 for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
3481 const auto &b_desc =
3482 for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
3483 const auto &c_desc =
3484 for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
3485 const auto &d_desc =
3486 for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
3487 cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
3488 blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
3489 preference.get(), max_algorithm_count, results.data(),
3490 &found_algorithm_count);
3491 if (status != CUBLAS_STATUS_SUCCESS) {
3492 return port::Status(
3493 port::error::INTERNAL,
3494 absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
3495 ToString(status)));
3496 }
3497 results.resize(found_algorithm_count);
3498 }
3499
3500 std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
3501 out_algorithms.reserve(results.size());
3502 for (size_t i = 0; i < results.size(); ++i) {
3503 const auto &result = results[i];
3504 if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
3505 out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
3506 i, result.algo, result.workspaceSize));
3507 }
3508 return out_algorithms;
3509 #else // if CUDA_VERSION < 11000
3510 return port::Status(
3511 port::error::UNIMPLEMENTED,
3512 "GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
3513 #endif
3514 }
3515
3516 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
3517 CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
3518 size_t max_workspace_size,
3519 int max_algorithm_count) {
3520 return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
3521 max_algorithm_count);
3522 }
3523
3524 #if CUDA_VERSION >= 11000
3525 bool CUDABlas::DoBlasLtMatmulInternal(
3526 Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
3527 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
3528 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
3529 DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator,
3530 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) {
3531 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3532 const auto &cuda_algo =
3533 *static_cast<const CUDABlasLtMatmulAlgorithm *>(algorithm);
3534
3535 if (alpha.data_type() != cuda_plan.scale_type() ||
3536 beta.data_type() != cuda_plan.scale_type()) {
3537 VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do "
3538 "not match plan: expected "
3539 << cuda_plan.scale_type() << ", got alpha=" << alpha.data_type()
3540 << " beta=" << beta.data_type();
3541 return false;
3542 }
3543 if (alpha.is_pointer() != beta.is_pointer()) {
3544 VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
3545 "and `beta` is a pointer, but the other is not.";
3546 return false;
3547 }
3548 bool is_pointer_mode_host = !alpha.is_pointer();
3549 if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
3550 is_pointer_mode_host) {
3551 VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
3552 "pointer_mode for the given alpha/beta.";
3553 return false;
3554 }
3555 if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
3556 cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
3557 (bias != nullptr)) {
3558 VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
3559 "epilogue for the given bias pointer.";
3560 return false;
3561 }
3562 const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
3563 : alpha.opaque_value();
3564 const void *beta_ptr =
3565 beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value();
3566
3567 void *workspace = nullptr;
3568 if (cuda_algo.workspace_size()) {
3569 port::Status allocation_status = AllocateWorkspace(
3570 &workspace, scratch_allocator, cuda_algo.workspace_size());
3571 if (!allocation_status.ok()) {
3572 if (err_on_failure || VLOG_IS_ON(3)) {
3573 LOG(ERROR)
3574 << "Failed to allocate workspace for cublasLtMatmul algo with id: "
3575 << cuda_algo.algo_id() << " requiring "
3576 << cuda_algo.workspace_size() << " bytes of workspace";
3577 }
3578 return false;
3579 }
3580 }
3581
3582 // This is only used when batch_count > kMaxBatchCount.
3583 std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
3584 if (cuda_plan.remainder_batch_count()) {
3585 // There is no easy way to get the user-specified max workspace size here,
3586 // so we just allow a very small amount and don't worry too much about
3587 // performance because this is only used in rare cases. The same reasoning
3588 // applies to selection of the algorithm.
3589 size_t max_workspace_size = 4 * 1024 * 1024; // 4 MiB
3590 auto status_or_algorithms =
3591 GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
3592 /* max_algorithm_count = */ 1,
3593 /* for_remainder_batch = */ true);
3594 if (!status_or_algorithms.ok()) {
3595 if (err_on_failure || VLOG_IS_ON(3)) {
3596 LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
3597 }
3598 return false;
3599 }
3600 auto algorithms = status_or_algorithms.ConsumeValueOrDie();
3601 unique_remainder_algo = std::move(algorithms.front());
3602 }
3603
3604 cudaStream_t cuda_stream = CUDAStream(stream);
3605
3606 absl::MutexLock lock(&mu_);
3607
3608 if (bias != nullptr) {
3609 if (!cuda_plan.SetBiasPointer(bias.opaque())) {
3610 VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
3611 "pointer failed.";
3612 return false;
3613 }
3614 }
3615
3616 CHECK(blasLt_ != nullptr);
3617
3618 gpu::ScopedActivateExecutorContext sac{parent_};
3619
3620 // Plan execution is broken down into repeat calls with capped_batch_count,
3621 // followed by a final call with remainder_batch_count.
3622 // Cases where batch_count <= kMaxBatchCount require only a single call (a
3623 // single loop iteration and no remainder).
3624 int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
3625 int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
3626 const char *a_ptr = static_cast<const char *>(a.opaque());
3627 const char *b_ptr = static_cast<const char *>(b.opaque());
3628 const char *c_ptr = static_cast<const char *>(c.opaque());
3629 char *d_ptr = static_cast<char *>(d.opaque());
3630 int capped_batch_count = cuda_plan.capped_batch_count();
3631 for (int batch = 0;
3632 batch + capped_batch_count <= cuda_plan.params().batch_count;
3633 batch += capped_batch_count) {
3634 cublasStatus_t ret = cublasLtMatmul(
3635 blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
3636 b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
3637 cuda_plan.d_desc(), cuda_algo.algo(), workspace,
3638 cuda_algo.workspace_size(), cuda_stream);
3639 if (ret != CUBLAS_STATUS_SUCCESS) {
3640 if (err_on_failure || VLOG_IS_ON(3)) {
3641 LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
3642 }
3643 return false;
3644 }
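    // Advance each matrix pointer past the chunk of capped_batch_count
    // batches just processed; the strides are element counts, so scale by the
    // element size in bytes.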
3645 a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
3646 b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
3647 c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
3648 d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
3649 }
3650 // This is only used when batch_count > kMaxBatchCount.
3651 if (cuda_plan.remainder_batch_count()) {
3652 const auto &remainder_algo =
3653 *static_cast<const CUDABlasLtMatmulAlgorithm *>(
3654 unique_remainder_algo.get());
3655 if (remainder_algo.workspace_size()) {
3656 port::Status allocation_status = AllocateWorkspace(
3657 &workspace, scratch_allocator, remainder_algo.workspace_size());
3658 if (!allocation_status.ok()) {
3659 if (err_on_failure || VLOG_IS_ON(3)) {
3660 LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
3661 "with id: "
3662 << remainder_algo.algo_id() << " requiring "
3663 << remainder_algo.workspace_size()
3664 << " bytes of workspace";
3665 }
3666 return false;
3667 }
3668 }
3669 cublasStatus_t ret = cublasLtMatmul(
3670 blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
3671 cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
3672 beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
3673 cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
3674 remainder_algo.workspace_size(), cuda_stream);
3675 if (ret != CUBLAS_STATUS_SUCCESS) {
3676 if (err_on_failure || VLOG_IS_ON(3)) {
3677 LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
3678 << ToString(ret);
3679 }
3680 return false;
3681 }
3682 }
3683 return true;
3684 }
3685 #endif // CUDA_VERSION >= 11000
3686
3687 bool CUDABlas::DoBlasLtMatmul(
3688 Stream *stream, const blas::IBlasLtMatmulPlan *plan,
3689 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
3690 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
3691 DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
3692 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
3693 blas::ProfileResult *output_profile_result) {
3694 #if CUDA_VERSION >= 11000
3695 const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
3696 HostOrDeviceScalar<void> alpha_cast = alpha;
3697 HostOrDeviceScalar<void> beta_cast = beta;
3698 if (cuda_plan.c_type() == blas::DataType::kHalf &&
3699 cuda_plan.scale_type() == blas::DataType::kFloat) {
3700 // The given alpha and beta types are F16 (they always match c), but F32*
3701 // computation type requires that they be F32, so we must cast them.
3702 if (alpha.is_pointer() || beta.is_pointer()) {
3703 // We cannot easily convert a pointer to f16 memory to a pointer to f32
3704 // memory from here, so we don't support this for now.
3705 return false;
3706 }
3707 alpha_cast = HostOrDeviceScalar<void>(
3708 static_cast<float>(alpha.value<Eigen::half>()));
3709 beta_cast =
3710 HostOrDeviceScalar<void>(static_cast<float>(beta.value<Eigen::half>()));
3711 }
3712
3713 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3714 if (output_profile_result) {
3715 timer.reset(new GpuTimer(parent_));
3716 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3717 return false;
3718 }
3719 }
3720
3721 bool err_on_failure = timer != nullptr;
3722 bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast,
3723 a, b, beta_cast, c, c, scratch_allocator,
3724 algorithm, bias);
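  // cublasLtMatmul computes D = alpha * op(A) * op(B) + beta * C; passing `c`
  // for both the C and D arguments above writes the result in place.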
3725
3726 if (timer && result) {
3727 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
3728 // state.
3729 if (!timer->Stop(AsGpuStream(stream))) {
3730 return false;
3731 }
3732 output_profile_result->set_is_valid(true);
3733 output_profile_result->set_algorithm(algorithm->index());
3734 output_profile_result->set_elapsed_time_in_ms(
3735 timer->GetElapsedMilliseconds());
3736 }
3737 return result;
3738 #else // if CUDA_VERSION < 11000
3739 return false;
3740 #endif
3741 }
3742
3743 port::Status CUDABlas::GetVersion(std::string *version) {
3744 absl::MutexLock lock(&mu_);
3745
3746 int v;
3747 auto status = cublasGetVersion(blas_, &v);
3748 if (status != CUBLAS_STATUS_SUCCESS) {
3749 return port::InternalError(ToString(status));
3750 }
3751 *version = std::to_string(v);
3752 return port::Status::OK();
3753 }
3754
3755 } // namespace gpu
3756
3757 void initialize_cublas() {
3758 port::Status status =
3759 PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
3760 cuda::kCudaPlatformId, gpu::kCuBlasPlugin, "cuBLAS",
3761 [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * {
3762 gpu::GpuExecutor *cuda_executor =
3763 dynamic_cast<gpu::GpuExecutor *>(parent);
3764 if (cuda_executor == nullptr) {
3765 LOG(ERROR)
3766 << "Attempting to initialize an instance of the cuBLAS "
3767 << "support library with a non-CUDA StreamExecutor";
3768 return nullptr;
3769 }
3770
3771 gpu::CUDABlas *blas = new gpu::CUDABlas(cuda_executor);
3772 if (!blas->Init()) {
3773 // Note: Init() will log a more specific error.
3774 delete blas;
3775 return nullptr;
3776 }
3777 return blas;
3778 });
3779
3780 if (!status.ok()) {
3781 LOG(ERROR) << "Unable to register cuBLAS factory: "
3782 << status.error_message();
3783 }
3784
3785 PluginRegistry::Instance()->SetDefaultFactory(
3786 cuda::kCudaPlatformId, PluginKind::kBlas, gpu::kCuBlasPlugin);
3787 }
3788
3789 } // namespace stream_executor
3790
3791 REGISTER_MODULE_INITIALIZER(register_cublas,
3792 { stream_executor::initialize_cublas(); });
3793