1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "rocm/include/rocblas.h"
17
18 #include "tensorflow/stream_executor/rocm/rocm_blas.h"
19
20 #define EIGEN_USE_GPU
21 #include <assert.h>
22
23 #include <complex>
24
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/stream_executor/device_memory.h"
29 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
30 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
31 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
32 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
33 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
34 #include "tensorflow/stream_executor/lib/env.h"
35 #include "tensorflow/stream_executor/lib/initialize.h"
36 #include "tensorflow/stream_executor/lib/status.h"
37 #include "tensorflow/stream_executor/lib/status_macros.h"
38 #include "tensorflow/stream_executor/platform/dso_loader.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream_executor.h"
45
46 namespace stream_executor {
47 namespace gpu {
48
49 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
50
51 namespace wrap {
52
// STREAM_EXECUTOR_ROCBLAS_WRAP(__name) defines a shim struct plus a global
// object called `__name`. Invoking that object with a GpuExecutor* followed by
// the rocBLAS arguments constructs a gpu::ScopedActivateExecutorContext for
// the executor and then forwards the remaining arguments to the rocBLAS entry
// point. The shim's static `kName` holds the routine name as a C string.
#ifdef PLATFORM_GOOGLE

// This flavour calls ::__name directly.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                       \
  struct WrapperShim__##__name {                                   \
    static const char *kName;                                      \
    template <typename... Args>                                    \
    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext activation{parent};       \
      return ::__name(args...);                                    \
    }                                                              \
  } __name;                                                        \
  const char *WrapperShim__##__name::kName = #__name;

#else

// This flavour looks the symbol up in the rocBLAS DSO handle the first time
// the wrapper is called (cached in a function-local static) and CHECK-fails
// if the symbol cannot be found.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                                 \
  struct DynLoadShim__##__name {                                             \
    static const char *kName;                                                \
    using FuncPtrT = decltype(&::__name);                                    \
    static void *GetDsoHandle() {                                            \
      auto handle_or = internal::CachedDsoLoader::GetRocblasDsoHandle();     \
      return handle_or.ValueOrDie();                                         \
    }                                                                        \
    static FuncPtrT LoadOrDie() {                                            \
      void *symbol;                                                          \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(),    \
                                                          kName, &symbol);   \
      CHECK(s.ok()) << "could not find " << kName                            \
                    << " in rocblas DSO; dlerror: " << s.error_message();    \
      return reinterpret_cast<FuncPtrT>(symbol);                             \
    }                                                                        \
    static FuncPtrT DynLoad() {                                              \
      static FuncPtrT resolved = LoadOrDie();                                \
      return resolved;                                                       \
    }                                                                        \
    template <typename... Args>                                              \
    rocblas_status operator()(GpuExecutor *parent, Args... args) {           \
      gpu::ScopedActivateExecutorContext activation{parent};                 \
      return DynLoad()(args...);                                             \
    }                                                                        \
  } __name;                                                                  \
  const char *DynLoadShim__##__name::kName = #__name;

#endif

// The "V2" spelling is a plain alias of the wrapper macro in both flavours.
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
102
103 // clang-format off
// X-macro over rocBLAS routine names: expands `__macro(name)` once for every
// routine listed below, in order. Entries enclosed in /* ... */ are disabled
// and therefore produce no expansion.
#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \
  __macro(rocblas_snrm2) \
  __macro(rocblas_dnrm2) \
  __macro(rocblas_scnrm2) \
  __macro(rocblas_dznrm2) \
  __macro(rocblas_sdot) \
  __macro(rocblas_ddot) \
  __macro(rocblas_cdotu) \
  __macro(rocblas_cdotc) \
  __macro(rocblas_zdotu) \
  __macro(rocblas_zdotc) \
  __macro(rocblas_sscal) \
  __macro(rocblas_dscal) \
  __macro(rocblas_cscal) \
  __macro(rocblas_csscal) \
  __macro(rocblas_zscal) \
  __macro(rocblas_zdscal) \
  __macro(rocblas_saxpy) \
  __macro(rocblas_daxpy) \
  __macro(rocblas_caxpy) \
  __macro(rocblas_zaxpy) \
  __macro(rocblas_scopy) \
  __macro(rocblas_dcopy) \
  __macro(rocblas_ccopy) \
  __macro(rocblas_zcopy) \
  __macro(rocblas_sswap) \
  __macro(rocblas_dswap) \
  __macro(rocblas_cswap) \
  __macro(rocblas_zswap) \
  __macro(rocblas_isamax) \
  __macro(rocblas_idamax) \
  __macro(rocblas_icamax) \
  __macro(rocblas_izamax) \
  __macro(rocblas_isamin) \
  __macro(rocblas_idamin) \
  __macro(rocblas_icamin) \
  __macro(rocblas_izamin) \
  __macro(rocblas_sasum) \
  __macro(rocblas_dasum) \
  __macro(rocblas_scasum) \
  __macro(rocblas_dzasum) \
  __macro(rocblas_srot) \
  __macro(rocblas_drot) \
  __macro(rocblas_crot) \
  __macro(rocblas_csrot) \
  __macro(rocblas_zrot) \
  __macro(rocblas_zdrot) \
  __macro(rocblas_srotg) \
  __macro(rocblas_drotg) \
  __macro(rocblas_crotg) \
  __macro(rocblas_zrotg) \
  __macro(rocblas_srotm) \
  __macro(rocblas_drotm) \
  __macro(rocblas_srotmg) \
  __macro(rocblas_drotmg) \
  __macro(rocblas_sgemv) \
  __macro(rocblas_dgemv) \
  __macro(rocblas_cgemv) \
  __macro(rocblas_zgemv) \
  __macro(rocblas_sgbmv) \
  __macro(rocblas_dgbmv) \
  __macro(rocblas_cgbmv) \
  __macro(rocblas_zgbmv) \
  __macro(rocblas_strmv) \
  __macro(rocblas_dtrmv) \
  __macro(rocblas_ctrmv) \
  __macro(rocblas_ztrmv) \
  __macro(rocblas_stbmv) \
  __macro(rocblas_dtbmv) \
  __macro(rocblas_ctbmv) \
  __macro(rocblas_ztbmv) \
  __macro(rocblas_stpmv) \
  __macro(rocblas_dtpmv) \
  __macro(rocblas_ctpmv) \
  __macro(rocblas_ztpmv) \
  __macro(rocblas_strsv) \
  __macro(rocblas_dtrsv) \
  __macro(rocblas_ctrsv) \
  __macro(rocblas_ztrsv) \
  __macro(rocblas_stpsv) \
  __macro(rocblas_dtpsv) \
  __macro(rocblas_ctpsv) \
  __macro(rocblas_ztpsv) \
  __macro(rocblas_stbsv) \
  __macro(rocblas_dtbsv) \
  __macro(rocblas_ctbsv) \
  __macro(rocblas_ztbsv) \
  __macro(rocblas_ssymv) \
  __macro(rocblas_dsymv) \
  /* __macro(rocblas_csymv) \
  __macro(rocblas_zsymv) */ \
  __macro(rocblas_chemv) \
  __macro(rocblas_zhemv) \
  __macro(rocblas_ssbmv) \
  __macro(rocblas_dsbmv) \
  __macro(rocblas_chbmv) \
  __macro(rocblas_zhbmv) \
  __macro(rocblas_sspmv) \
  __macro(rocblas_dspmv) \
  __macro(rocblas_chpmv) \
  __macro(rocblas_zhpmv) \
  __macro(rocblas_sger) \
  __macro(rocblas_dger) \
  __macro(rocblas_cgeru) \
  __macro(rocblas_cgerc) \
  __macro(rocblas_zgeru) \
  __macro(rocblas_zgerc) \
  __macro(rocblas_ssyr) \
  __macro(rocblas_dsyr) \
  /*__macro(rocblas_csyr) \
  __macro(rocblas_zsyr) */ \
  __macro(rocblas_cher) \
  __macro(rocblas_zher) \
  __macro(rocblas_sspr) \
  __macro(rocblas_dspr) \
  __macro(rocblas_chpr) \
  __macro(rocblas_zhpr) \
  __macro(rocblas_ssyr2) \
  __macro(rocblas_dsyr2) \
  /* __macro(rocblas_csyr2) \
  __macro(rocblas_zsyr2) */ \
  __macro(rocblas_cher2) \
  __macro(rocblas_zher2) \
  __macro(rocblas_sspr2) \
  __macro(rocblas_dspr2) \
  __macro(rocblas_chpr2) \
  __macro(rocblas_zhpr2) \
  __macro(rocblas_sgemm) \
  __macro(rocblas_dgemm) \
  __macro(rocblas_hgemm) \
  __macro(rocblas_cgemm) \
  __macro(rocblas_zgemm) \
  __macro(rocblas_ssyrk) \
  __macro(rocblas_dsyrk) \
  __macro(rocblas_csyrk) \
  __macro(rocblas_zsyrk) \
  __macro(rocblas_cherk) \
  __macro(rocblas_zherk) \
  __macro(rocblas_ssyr2k) \
  __macro(rocblas_dsyr2k) \
  __macro(rocblas_csyr2k) \
  __macro(rocblas_zsyr2k) \
  __macro(rocblas_cher2k) \
  __macro(rocblas_zher2k) \
  /* __macro(rocblas_ssyrkx) \
  __macro(rocblas_dsyrkx) \
  __macro(rocblas_csyrkx) \
  __macro(rocblas_zsyrkx) \
  __macro(rocblas_cherkx) \
  __macro(rocblas_zherkx) */ \
  __macro(rocblas_ssymm) \
  __macro(rocblas_dsymm) \
  __macro(rocblas_csymm) \
  __macro(rocblas_zsymm) \
  __macro(rocblas_chemm) \
  __macro(rocblas_zhemm) \
  __macro(rocblas_strsm) \
  __macro(rocblas_dtrsm) \
  __macro(rocblas_ctrsm) \
  __macro(rocblas_ztrsm) \
  __macro(rocblas_strmm) \
  __macro(rocblas_dtrmm) \
  __macro(rocblas_ctrmm) \
  __macro(rocblas_ztrmm) \
  __macro(rocblas_sgeam) \
  __macro(rocblas_dgeam) \
  /*__macro(rocblas_cgeam) \
  __macro(rocblas_zgeam) \
  __macro(rocblas_sdgmm) \
  __macro(rocblas_ddgmm) \
  __macro(rocblas_cdgmm) \
  __macro(rocblas_zdgmm) */
276 // clang-format on
277
// Instantiate the wrapper objects used by this file: handle management and
// stream binding first, then the strided-batched GEMM variants, and finally
// one wrapper per routine listed in ROCBLAS_BLAS_ROUTINE_EACH. The
// commented-out lines are wrappers that are currently not instantiated.
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)
293
294 } // namespace wrap
295
296 template <class T>
complex_cast(const DeviceMemory<T> & a)297 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
298 const DeviceMemory<T> &a) {
299 return reinterpret_cast<
300 const typename RocBlasTypeConversionHelper<T>::mapped_type *>(
301 GpuMemory(a));
302 }
303 template <class T>
complex_cast(const T & a)304 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
305 const T &a) {
306 return reinterpret_cast<
307 const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a);
308 }
309 template <class T>
complex_cast(DeviceMemory<T> * a)310 typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
311 DeviceMemory<T> *a) {
312 return reinterpret_cast<
313 typename RocBlasTypeConversionHelper<T>::mapped_type *>(
314 GpuMemoryMutable(a));
315 }
316
// Logging hook called by some of the BLAS entry points in this file. It is
// intentionally a no-op: the message is accepted and discarded.
static void blas_log(const char *c) {
  (void)c;  // Unused on purpose.
}
318
ToString(rocblas_status status)319 static string ToString(rocblas_status status) {
320 switch (status) {
321 case rocblas_status_success:
322 return "rocblas_status_success";
323 case rocblas_status_invalid_handle:
324 return "rocblas_status_invalid_handle";
325 case rocblas_status_not_implemented:
326 return "rocblas_status_not_implemented";
327 case rocblas_status_invalid_pointer:
328 return "rocblas_status_invalid_pointer";
329 case rocblas_status_invalid_size:
330 return "rocblas_status_invalid_size";
331 case rocblas_status_memory_error:
332 return "rocblas_status_memory_error";
333 case rocblas_status_internal_error:
334 return "rocblas_status_internal_error";
335 default:
336 return absl::StrCat("<invalid rocBLAS status: ", status, ">");
337 }
338 }
339
Init()340 bool ROCMBlas::Init() {
341 rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_);
342 if (ret != rocblas_status_success) {
343 LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
344 return false;
345 }
346
347 return true;
348 }
349
ROCMBlas(gpu::GpuExecutor * parent)350 ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
351 : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
352
~ROCMBlas()353 ROCMBlas::~ROCMBlas() {
354 if (blas_ != nullptr) {
355 wrap::rocblas_destroy_handle(parent_, blas_);
356 }
357 }
358
SetStream(Stream * stream)359 bool ROCMBlas::SetStream(Stream *stream) {
360 CHECK(stream != nullptr);
361 CHECK(AsGpuStreamValue(stream) != nullptr);
362 CHECK(blas_ != nullptr);
363 rocblas_status ret =
364 wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream));
365 if (ret != rocblas_status_success) {
366 LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
367 return false;
368 }
369
370 return true;
371 }
372
373 namespace {
374
375 // Helper functions transforming blas arguments into rocBLAS arguments.
376
ROCMBlasTranspose(blas::Transpose trans)377 rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
378 switch (trans) {
379 case blas::Transpose::kNoTranspose:
380 return rocblas_operation_none;
381 case blas::Transpose::kTranspose:
382 return rocblas_operation_transpose;
383 case blas::Transpose::kConjugateTranspose:
384 return rocblas_operation_conjugate_transpose;
385 default:
386 LOG(FATAL) << "Invalid value of blas::Transpose.";
387 }
388 }
389
ROCMBlasUpperLower(blas::UpperLower uplo)390 rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
391 switch (uplo) {
392 case blas::UpperLower::kUpper:
393 return rocblas_fill_upper;
394 case blas::UpperLower::kLower:
395 return rocblas_fill_lower;
396 default:
397 LOG(FATAL) << "Invalid value of blas::UpperLower.";
398 }
399 }
400
ROCMBlasDiagonal(blas::Diagonal diag)401 rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
402 switch (diag) {
403 case blas::Diagonal::kUnit:
404 return rocblas_diagonal_unit;
405 case blas::Diagonal::kNonUnit:
406 return rocblas_diagonal_non_unit;
407 default:
408 LOG(FATAL) << "Invalid value of blas::Diagonal.";
409 }
410 }
411
ROCMBlasSide(blas::Side side)412 rocblas_side ROCMBlasSide(blas::Side side) {
413 switch (side) {
414 case blas::Side::kLeft:
415 return rocblas_side_left;
416 case blas::Side::kRight:
417 return rocblas_side_right;
418 default:
419 LOG(FATAL) << "Invalid value of blas::Side.";
420 }
421 }
422
423 } // namespace
424
425 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,Args...args)426 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
427 bool pointer_mode_host, bool err_on_failure,
428 Args... args) {
429 absl::MutexLock lock{&mu_};
430
431 CHECK(blas_ != nullptr);
432 if (!SetStream(stream)) {
433 return false;
434 }
435
436 rocblas_status ret = rocblas_func(parent_, blas_, args...);
437 if (err_on_failure && ret != rocblas_status_success) {
438 LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
439 << ToString(ret);
440 }
441 return ret == rocblas_status_success;
442 }
443
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)444 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
445 const DeviceMemory<float> &x, int incx,
446 DeviceMemory<float> *result) {
447 return DoBlasInternal(wrap::rocblas_sasum, stream,
448 /* pointer_mode_host = */ false, elem_count,
449 GpuMemory(x), incx, GpuMemoryMutable(result));
450 }
451
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)452 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
453 const DeviceMemory<double> &x, int incx,
454 DeviceMemory<double> *result) {
455 return DoBlasInternal(wrap::rocblas_dasum, stream,
456 /* pointer_mode_host = */ false, elem_count,
457 GpuMemory(x), incx, GpuMemoryMutable(result));
458 }
459
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)460 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
461 const DeviceMemory<std::complex<float>> &x, int incx,
462 DeviceMemory<float> *result) {
463 return DoBlasInternal(wrap::rocblas_scasum, stream,
464 /* pointer_mode_host = */ false, elem_count,
465 complex_cast(x), incx, GpuMemoryMutable(result));
466 }
467
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)468 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
469 const DeviceMemory<std::complex<double>> &x, int incx,
470 DeviceMemory<double> *result) {
471 return DoBlasInternal(wrap::rocblas_dzasum, stream,
472 /* pointer_mode_host = */ false, elem_count,
473 complex_cast(x), incx, GpuMemoryMutable(result));
474 }
475
DoBlasAxpy(Stream * stream,uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)476 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
477 const DeviceMemory<float> &x, int incx,
478 DeviceMemory<float> *y, int incy) {
479 blas_log("DoBlasAxpy");
480 return DoBlasInternal(wrap::rocblas_saxpy, stream,
481 /* pointer_mode_host = */ true, elem_count, &alpha,
482 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
483 }
484
DoBlasAxpy(Stream * stream,uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)485 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
486 const DeviceMemory<double> &x, int incx,
487 DeviceMemory<double> *y, int incy) {
488 blas_log("DoBlasAxpy");
489 return DoBlasInternal(wrap::rocblas_daxpy, stream,
490 /* pointer_mode_host = */ true, elem_count, &alpha,
491 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
492 }
493
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)494 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
495 std::complex<float> alpha,
496 const DeviceMemory<std::complex<float>> &x, int incx,
497 DeviceMemory<std::complex<float>> *y, int incy) {
498 return DoBlasInternal(
499 wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count,
500 complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
501 }
502
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)503 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
504 std::complex<double> alpha,
505 const DeviceMemory<std::complex<double>> &x, int incx,
506 DeviceMemory<std::complex<double>> *y, int incy) {
507 return DoBlasInternal(
508 wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count,
509 complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
510 }
511
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)512 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
513 const DeviceMemory<float> &x, int incx,
514 DeviceMemory<float> *y, int incy) {
515 return DoBlasInternal(wrap::rocblas_scopy, stream,
516 /* pointer_mode_host = */ true, elem_count,
517 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
518 }
519
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)520 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
521 const DeviceMemory<double> &x, int incx,
522 DeviceMemory<double> *y, int incy) {
523 return DoBlasInternal(wrap::rocblas_dcopy, stream,
524 /* pointer_mode_host = */ true, elem_count,
525 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
526 }
527
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)528 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
529 const DeviceMemory<std::complex<float>> &x, int incx,
530 DeviceMemory<std::complex<float>> *y, int incy) {
531 return DoBlasInternal(wrap::rocblas_ccopy, stream,
532 /* pointer_mode_host = */ true, elem_count,
533 complex_cast(x), incx, complex_cast(y), incy);
534 }
535
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)536 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
537 const DeviceMemory<std::complex<double>> &x, int incx,
538 DeviceMemory<std::complex<double>> *y, int incy) {
539 return DoBlasInternal(wrap::rocblas_zcopy, stream,
540 /* pointer_mode_host = */ true, elem_count,
541 complex_cast(x), incx, complex_cast(y), incy);
542 }
543
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)544 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
545 const DeviceMemory<float> &x, int incx,
546 const DeviceMemory<float> &y, int incy,
547 DeviceMemory<float> *result) {
548 blas_log("DoBlasDot");
549 return DoBlasInternal(
550 wrap::rocblas_sdot, stream, /* pointer_mode_host = */ false, elem_count,
551 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
552 }
553
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)554 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
555 const DeviceMemory<double> &x, int incx,
556 const DeviceMemory<double> &y, int incy,
557 DeviceMemory<double> *result) {
558 blas_log("DoBlasDot");
559 return DoBlasInternal(
560 wrap::rocblas_ddot, stream, /* pointer_mode_host = */ false, elem_count,
561 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
562 }
563
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)564 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
565 const DeviceMemory<std::complex<float>> &x, int incx,
566 const DeviceMemory<std::complex<float>> &y, int incy,
567 DeviceMemory<std::complex<float>> *result) {
568 return DoBlasInternal(
569 wrap::rocblas_cdotc, stream, /* pointer_mode_host = */ false, elem_count,
570 complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
571 }
572
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)573 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
574 const DeviceMemory<std::complex<double>> &x, int incx,
575 const DeviceMemory<std::complex<double>> &y, int incy,
576 DeviceMemory<std::complex<double>> *result) {
577 return DoBlasInternal(
578 wrap::rocblas_zdotc, stream, /* pointer_mode_host = */ false, elem_count,
579 complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
580 }
581
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)582 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
583 const DeviceMemory<std::complex<float>> &x, int incx,
584 const DeviceMemory<std::complex<float>> &y, int incy,
585 DeviceMemory<std::complex<float>> *result) {
586 return DoBlasInternal(
587 wrap::rocblas_cdotu, stream, /* pointer_mode_host = */ false, elem_count,
588 complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
589 }
590
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)591 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
592 const DeviceMemory<std::complex<double>> &x, int incx,
593 const DeviceMemory<std::complex<double>> &y, int incy,
594 DeviceMemory<std::complex<double>> *result) {
595 return DoBlasInternal(
596 wrap::rocblas_zdotu, stream, /* pointer_mode_host = */ false, elem_count,
597 complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
598 }
599
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)600 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
601 const DeviceMemory<float> &x, int incx,
602 DeviceMemory<float> *result) {
603 return DoBlasInternal(wrap::rocblas_snrm2, stream,
604 /* pointer_mode_host = */ false, elem_count,
605 GpuMemory(x), incx, GpuMemoryMutable(result));
606 }
607
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)608 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
609 const DeviceMemory<double> &x, int incx,
610 DeviceMemory<double> *result) {
611 return DoBlasInternal(wrap::rocblas_dnrm2, stream,
612 /* pointer_mode_host = */ false, elem_count,
613 GpuMemory(x), incx, GpuMemoryMutable(result));
614 }
615
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)616 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
617 const DeviceMemory<std::complex<float>> &x, int incx,
618 DeviceMemory<float> *result) {
619 return DoBlasInternal(wrap::rocblas_scnrm2, stream,
620 /* pointer_mode_host = */ false, elem_count,
621 complex_cast(x), incx, GpuMemoryMutable(result));
622 }
623
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)624 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
625 const DeviceMemory<std::complex<double>> &x, int incx,
626 DeviceMemory<double> *result) {
627 return DoBlasInternal(wrap::rocblas_dznrm2, stream,
628 /* pointer_mode_host = */ false, elem_count,
629 complex_cast(x), incx, GpuMemoryMutable(result));
630 }
631
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)632 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
633 DeviceMemory<float> *x, int incx,
634 DeviceMemory<float> *y, int incy, float c, float s) {
635 return DoBlasInternal(
636 wrap::rocblas_srot, stream, /* pointer_mode_host = */ true, elem_count,
637 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s);
638 }
639
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)640 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
641 DeviceMemory<double> *x, int incx,
642 DeviceMemory<double> *y, int incy, double c,
643 double s) {
644 return DoBlasInternal(
645 wrap::rocblas_drot, stream, /* pointer_mode_host = */ true, elem_count,
646 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s);
647 }
648
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)649 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
650 DeviceMemory<std::complex<float>> *x, int incx,
651 DeviceMemory<std::complex<float>> *y, int incy,
652 float c, float s) {
653 return DoBlasInternal(wrap::rocblas_csrot, stream,
654 /* pointer_mode_host = */ true, elem_count,
655 complex_cast(x), incx, complex_cast(y), incy, &c, &s);
656 }
657
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)658 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
659 DeviceMemory<std::complex<double>> *x, int incx,
660 DeviceMemory<std::complex<double>> *y, int incy,
661 double c, double s) {
662 return DoBlasInternal(wrap::rocblas_zdrot, stream,
663 /* pointer_mode_host = */ true, elem_count,
664 complex_cast(x), incx, complex_cast(y), incy, &c, &s);
665 }
666
DoBlasRotg(Stream * stream,DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)667 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
668 DeviceMemory<float> *b, DeviceMemory<float> *c,
669 DeviceMemory<float> *s) {
670 return DoBlasInternal(wrap::rocblas_srotg, stream,
671 /* pointer_mode_host = */ false, GpuMemoryMutable(a),
672 GpuMemoryMutable(b), GpuMemoryMutable(c),
673 GpuMemoryMutable(s));
674 }
675
DoBlasRotg(Stream * stream,DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)676 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
677 DeviceMemory<double> *b, DeviceMemory<double> *c,
678 DeviceMemory<double> *s) {
679 return DoBlasInternal(wrap::rocblas_drotg, stream,
680 /* pointer_mode_host = */ false, GpuMemoryMutable(a),
681 GpuMemoryMutable(b), GpuMemoryMutable(c),
682 GpuMemoryMutable(s));
683 }
684
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)685 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
686 DeviceMemory<std::complex<float>> *b,
687 DeviceMemory<float> *c,
688 DeviceMemory<std::complex<float>> *s) {
689 return DoBlasInternal(wrap::rocblas_crotg, stream,
690 /* pointer_mode_host = */ false, complex_cast(a),
691 complex_cast(b), GpuMemoryMutable(c), complex_cast(s));
692 }
693
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)694 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
695 DeviceMemory<std::complex<double>> *b,
696 DeviceMemory<double> *c,
697 DeviceMemory<std::complex<double>> *s) {
698 return DoBlasInternal(wrap::rocblas_zrotg, stream,
699 /* pointer_mode_host = */ false, complex_cast(a),
700 complex_cast(b), GpuMemoryMutable(c), complex_cast(s));
701 }
702
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)703 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
704 DeviceMemory<float> *x, int incx,
705 DeviceMemory<float> *y, int incy,
706 const DeviceMemory<float> ¶m) {
707 return DoBlasInternal(
708 wrap::rocblas_srotm, stream, /* pointer_mode_host = */ false, elem_count,
709 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param));
710 }
711
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)712 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
713 DeviceMemory<double> *x, int incx,
714 DeviceMemory<double> *y, int incy,
715 const DeviceMemory<double> ¶m) {
716 return DoBlasInternal(
717 wrap::rocblas_drotm, stream, /* pointer_mode_host = */ false, elem_count,
718 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param));
719 }
720
DoBlasRotmg(Stream * stream,DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)721 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
722 DeviceMemory<float> *d2, DeviceMemory<float> *x1,
723 const DeviceMemory<float> &y1,
724 DeviceMemory<float> *param) {
725 return DoBlasInternal(wrap::rocblas_srotmg, stream,
726 /* pointer_mode_host = */ false, GpuMemoryMutable(d1),
727 GpuMemoryMutable(d2), GpuMemoryMutable(x1),
728 GpuMemory(y1), GpuMemoryMutable(param));
729 }
730
DoBlasRotmg(Stream * stream,DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)731 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
732 DeviceMemory<double> *d2, DeviceMemory<double> *x1,
733 const DeviceMemory<double> &y1,
734 DeviceMemory<double> *param) {
735 return DoBlasInternal(wrap::rocblas_drotmg, stream,
736 /* pointer_mode_host = */ false, GpuMemoryMutable(d1),
737 GpuMemoryMutable(d2), GpuMemoryMutable(x1),
738 GpuMemory(y1), GpuMemoryMutable(param));
739 }
740
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)741 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
742 DeviceMemory<float> *x, int incx) {
743 blas_log("DoBlasScal<float>");
744 return DoBlasInternal(wrap::rocblas_sscal, stream,
745 /* pointer_mode_host = */ true, elem_count, &alpha,
746 GpuMemoryMutable(x), incx);
747 }
748
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)749 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
750 DeviceMemory<double> *x, int incx) {
751 return DoBlasInternal(wrap::rocblas_dscal, stream,
752 /* pointer_mode_host = */ true, elem_count, &alpha,
753 GpuMemoryMutable(x), incx);
754 }
755
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)756 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
757 DeviceMemory<std::complex<float>> *x, int incx) {
758 return DoBlasInternal(wrap::rocblas_csscal, stream,
759 /* pointer_mode_host = */ true, elem_count, &alpha,
760 complex_cast(x), incx);
761 }
762
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)763 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
764 DeviceMemory<std::complex<double>> *x, int incx) {
765 return DoBlasInternal(wrap::rocblas_zdscal, stream,
766 /* pointer_mode_host = */ true, elem_count, &alpha,
767 complex_cast(x), incx);
768 }
769
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)770 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
771 std::complex<float> alpha,
772 DeviceMemory<std::complex<float>> *x, int incx) {
773 return DoBlasInternal(wrap::rocblas_cscal, stream,
774 /* pointer_mode_host = */ true, elem_count,
775 complex_cast(alpha), complex_cast(x), incx);
776 }
777
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)778 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
779 std::complex<double> alpha,
780 DeviceMemory<std::complex<double>> *x, int incx) {
781 return DoBlasInternal(wrap::rocblas_zscal, stream,
782 /* pointer_mode_host = */ true, elem_count,
783 complex_cast(alpha), complex_cast(x), incx);
784 }
785
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)786 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
787 DeviceMemory<float> *x, int incx,
788 DeviceMemory<float> *y, int incy) {
789 return DoBlasInternal(wrap::rocblas_sswap, stream,
790 /* pointer_mode_host = */ true, elem_count,
791 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
792 }
793
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)794 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
795 DeviceMemory<double> *x, int incx,
796 DeviceMemory<double> *y, int incy) {
797 return DoBlasInternal(wrap::rocblas_dswap, stream,
798 /* pointer_mode_host = */ true, elem_count,
799 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
800 }
801
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)802 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
803 DeviceMemory<std::complex<float>> *x, int incx,
804 DeviceMemory<std::complex<float>> *y, int incy) {
805 return DoBlasInternal(wrap::rocblas_cswap, stream,
806 /* pointer_mode_host = */ true, elem_count,
807 complex_cast(x), incx, complex_cast(y), incy);
808 }
809
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)810 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
811 DeviceMemory<std::complex<double>> *x, int incx,
812 DeviceMemory<std::complex<double>> *y, int incy) {
813 return DoBlasInternal(wrap::rocblas_zswap, stream,
814 /* pointer_mode_host = */ true, elem_count,
815 complex_cast(x), incx, complex_cast(y), incy);
816 }
817
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)818 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
819 const DeviceMemory<float> &x, int incx,
820 DeviceMemory<int> *result) {
821 return DoBlasInternal(wrap::rocblas_isamax, stream,
822 /* pointer_mode_host = */ false, elem_count,
823 GpuMemory(x), incx, GpuMemoryMutable(result));
824 }
825
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)826 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
827 const DeviceMemory<double> &x, int incx,
828 DeviceMemory<int> *result) {
829 return DoBlasInternal(wrap::rocblas_idamax, stream,
830 /* pointer_mode_host = */ false, elem_count,
831 GpuMemory(x), incx, GpuMemoryMutable(result));
832 }
833
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)834 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
835 const DeviceMemory<std::complex<float>> &x, int incx,
836 DeviceMemory<int> *result) {
837 return DoBlasInternal(wrap::rocblas_icamax, stream,
838 /* pointer_mode_host = */ false, elem_count,
839 complex_cast(x), incx, GpuMemoryMutable(result));
840 }
841
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)842 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
843 const DeviceMemory<std::complex<double>> &x,
844 int incx, DeviceMemory<int> *result) {
845 return DoBlasInternal(wrap::rocblas_izamax, stream,
846 /* pointer_mode_host = */ false, elem_count,
847 complex_cast(x), incx, GpuMemoryMutable(result));
848 }
849
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)850 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
851 const DeviceMemory<float> &x, int incx,
852 DeviceMemory<int> *result) {
853 return DoBlasInternal(wrap::rocblas_isamin, stream,
854 /* pointer_mode_host = */ false, elem_count,
855 GpuMemory(x), incx, GpuMemoryMutable(result));
856 }
857
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)858 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
859 const DeviceMemory<double> &x, int incx,
860 DeviceMemory<int> *result) {
861 return DoBlasInternal(wrap::rocblas_idamin, stream,
862 /* pointer_mode_host = */ false, elem_count,
863 GpuMemory(x), incx, GpuMemoryMutable(result));
864 }
865
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)866 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
867 const DeviceMemory<std::complex<float>> &x, int incx,
868 DeviceMemory<int> *result) {
869 return DoBlasInternal(wrap::rocblas_icamin, stream,
870 /* pointer_mode_host = */ false, elem_count,
871 complex_cast(x), incx, GpuMemoryMutable(result));
872 }
873
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)874 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
875 const DeviceMemory<std::complex<double>> &x,
876 int incx, DeviceMemory<int> *result) {
877 return DoBlasInternal(wrap::rocblas_izamin, stream,
878 /* pointer_mode_host = */ false, elem_count,
879 complex_cast(x), incx, GpuMemoryMutable(result));
880 }
881
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)882 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
883 uint64 n, uint64 kl, uint64 ku, float alpha,
884 const DeviceMemory<float> &a, int lda,
885 const DeviceMemory<float> &x, int incx, float beta,
886 DeviceMemory<float> *y, int incy) {
887 return DoBlasInternal(
888 wrap::rocblas_sgbmv, stream, /* pointer_mode_host = */ true,
889 ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda,
890 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
891 }
892
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)893 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
894 uint64 n, uint64 kl, uint64 ku, double alpha,
895 const DeviceMemory<double> &a, int lda,
896 const DeviceMemory<double> &x, int incx, double beta,
897 DeviceMemory<double> *y, int incy) {
898 return DoBlasInternal(
899 wrap::rocblas_dgbmv, stream, /* pointer_mode_host = */ true,
900 ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda,
901 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
902 }
903
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)904 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
905 uint64 n, uint64 kl, uint64 ku,
906 std::complex<float> alpha,
907 const DeviceMemory<std::complex<float>> &a, int lda,
908 const DeviceMemory<std::complex<float>> &x, int incx,
909 std::complex<float> beta,
910 DeviceMemory<std::complex<float>> *y, int incy) {
911 return DoBlasInternal(
912 wrap::rocblas_cgbmv, stream, /* pointer_mode_host = */ true,
913 ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha),
914 complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta),
915 complex_cast(y), incy);
916 }
917
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)918 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
919 uint64 n, uint64 kl, uint64 ku,
920 std::complex<double> alpha,
921 const DeviceMemory<std::complex<double>> &a, int lda,
922 const DeviceMemory<std::complex<double>> &x, int incx,
923 std::complex<double> beta,
924 DeviceMemory<std::complex<double>> *y, int incy) {
925 return DoBlasInternal(
926 wrap::rocblas_zgbmv, stream, /* pointer_mode_host = */ true,
927 ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha),
928 complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta),
929 complex_cast(y), incy);
930 }
931
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)932 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
933 uint64 n, float alpha, const DeviceMemory<float> &a,
934 int lda, const DeviceMemory<float> &x, int incx,
935 float beta, DeviceMemory<float> *y, int incy) {
936 blas_log("DoBlasGemv");
937 return DoBlasInternal(
938 wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true,
939 ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
940 incx, &beta, GpuMemoryMutable(y), incy);
941 }
942
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)943 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
944 uint64 n, double alpha, const DeviceMemory<double> &a,
945 int lda, const DeviceMemory<double> &x, int incx,
946 double beta, DeviceMemory<double> *y, int incy) {
947 blas_log("DoBlasGemv");
948 return DoBlasInternal(
949 wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true,
950 ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
951 incx, &beta, GpuMemoryMutable(y), incy);
952 }
953
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)954 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
955 uint64 n, std::complex<float> alpha,
956 const DeviceMemory<std::complex<float>> &a, int lda,
957 const DeviceMemory<std::complex<float>> &x, int incx,
958 std::complex<float> beta,
959 DeviceMemory<std::complex<float>> *y, int incy) {
960 blas_log("DoBlasGemv");
961 return DoBlasInternal(
962 wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true,
963 ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
964 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
965 }
966
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)967 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
968 uint64 n, std::complex<double> alpha,
969 const DeviceMemory<std::complex<double>> &a, int lda,
970 const DeviceMemory<std::complex<double>> &x, int incx,
971 std::complex<double> beta,
972 DeviceMemory<std::complex<double>> *y, int incy) {
973 blas_log("DoBlasGemv\n");
974 return DoBlasInternal(
975 wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true,
976 ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
977 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
978 }
979
DoBlasGer(Stream * stream,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)980 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
981 const DeviceMemory<float> &x, int incx,
982 const DeviceMemory<float> &y, int incy,
983 DeviceMemory<float> *a, int lda) {
984 return DoBlasInternal(
985 wrap::rocblas_sger, stream, /* pointer_mode_host = */ true, m, n, &alpha,
986 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
987 }
988
DoBlasGer(Stream * stream,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)989 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
990 const DeviceMemory<double> &x, int incx,
991 const DeviceMemory<double> &y, int incy,
992 DeviceMemory<double> *a, int lda) {
993 return DoBlasInternal(
994 wrap::rocblas_dger, stream, /* pointer_mode_host = */ true, m, n, &alpha,
995 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
996 }
997
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)998 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
999 std::complex<float> alpha,
1000 const DeviceMemory<std::complex<float>> &x, int incx,
1001 const DeviceMemory<std::complex<float>> &y, int incy,
1002 DeviceMemory<std::complex<float>> *a, int lda) {
1003 return DoBlasInternal(wrap::rocblas_cgerc, stream,
1004 /* pointer_mode_host = */ true, m, n,
1005 complex_cast(alpha), complex_cast(x), incx,
1006 complex_cast(y), incy, complex_cast(a), lda);
1007 }
1008
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1009 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1010 std::complex<double> alpha,
1011 const DeviceMemory<std::complex<double>> &x, int incx,
1012 const DeviceMemory<std::complex<double>> &y, int incy,
1013 DeviceMemory<std::complex<double>> *a, int lda) {
1014 return DoBlasInternal(wrap::rocblas_zgerc, stream,
1015 /* pointer_mode_host = */ true, m, n,
1016 complex_cast(alpha), complex_cast(x), incx,
1017 complex_cast(y), incy, complex_cast(a), lda);
1018 }
1019
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)1020 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1021 std::complex<float> alpha,
1022 const DeviceMemory<std::complex<float>> &x, int incx,
1023 const DeviceMemory<std::complex<float>> &y, int incy,
1024 DeviceMemory<std::complex<float>> *a, int lda) {
1025 return DoBlasInternal(wrap::rocblas_cgeru, stream,
1026 /* pointer_mode_host = */ true, m, n,
1027 complex_cast(alpha), complex_cast(x), incx,
1028 complex_cast(y), incy, complex_cast(a), lda);
1029 }
1030
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1031 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1032 std::complex<double> alpha,
1033 const DeviceMemory<std::complex<double>> &x, int incx,
1034 const DeviceMemory<std::complex<double>> &y, int incy,
1035 DeviceMemory<std::complex<double>> *a, int lda) {
1036 return DoBlasInternal(wrap::rocblas_zgeru, stream,
1037 /* pointer_mode_host = */ true, m, n,
1038 complex_cast(alpha), complex_cast(x), incx,
1039 complex_cast(y), incy, complex_cast(a), lda);
1040 }
1041
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1042 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1043 uint64 k, std::complex<float> alpha,
1044 const DeviceMemory<std::complex<float>> &a, int lda,
1045 const DeviceMemory<std::complex<float>> &x, int incx,
1046 std::complex<float> beta,
1047 DeviceMemory<std::complex<float>> *y, int incy) {
1048 return DoBlasInternal(
1049 wrap::rocblas_chbmv, stream, /* pointer_mode_host = */ true,
1050 ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda,
1051 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1052 }
1053
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1054 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1055 uint64 k, std::complex<double> alpha,
1056 const DeviceMemory<std::complex<double>> &a, int lda,
1057 const DeviceMemory<std::complex<double>> &x, int incx,
1058 std::complex<double> beta,
1059 DeviceMemory<std::complex<double>> *y, int incy) {
1060 return DoBlasInternal(
1061 wrap::rocblas_zhbmv, stream, /* pointer_mode_host = */ true,
1062 ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda,
1063 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1064 }
1065
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1066 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1067 std::complex<float> alpha,
1068 const DeviceMemory<std::complex<float>> &a, int lda,
1069 const DeviceMemory<std::complex<float>> &x, int incx,
1070 std::complex<float> beta,
1071 DeviceMemory<std::complex<float>> *y, int incy) {
1072 return DoBlasInternal(
1073 wrap::rocblas_chemv, stream, /* pointer_mode_host = */ true,
1074 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda,
1075 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1076 }
1077
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1078 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1079 std::complex<double> alpha,
1080 const DeviceMemory<std::complex<double>> &a, int lda,
1081 const DeviceMemory<std::complex<double>> &x, int incx,
1082 std::complex<double> beta,
1083 DeviceMemory<std::complex<double>> *y, int incy) {
1084 return DoBlasInternal(
1085 wrap::rocblas_zhemv, stream, /* pointer_mode_host = */ true,
1086 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda,
1087 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1088 }
1089
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)1090 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1091 float alpha,
1092 const DeviceMemory<std::complex<float>> &x, int incx,
1093 DeviceMemory<std::complex<float>> *a, int lda) {
1094 return DoBlasInternal(wrap::rocblas_cher, stream,
1095 /* pointer_mode_host = */ true,
1096 ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1097 complex_cast(x), incx, complex_cast(a), lda);
1098 }
1099
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)1100 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1101 double alpha,
1102 const DeviceMemory<std::complex<double>> &x, int incx,
1103 DeviceMemory<std::complex<double>> *a, int lda) {
1104 return DoBlasInternal(wrap::rocblas_zher, stream,
1105 /* pointer_mode_host = */ true,
1106 ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1107 complex_cast(x), incx, complex_cast(a), lda);
1108 }
1109
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)1110 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1111 std::complex<float> alpha,
1112 const DeviceMemory<std::complex<float>> &x, int incx,
1113 const DeviceMemory<std::complex<float>> &y, int incy,
1114 DeviceMemory<std::complex<float>> *a, int lda) {
1115 return DoBlasInternal(
1116 wrap::rocblas_cher2, stream, /* pointer_mode_host = */ true,
1117 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1118 complex_cast(y), incy, complex_cast(a), lda);
1119 }
1120
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1121 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1122 std::complex<double> alpha,
1123 const DeviceMemory<std::complex<double>> &x, int incx,
1124 const DeviceMemory<std::complex<double>> &y, int incy,
1125 DeviceMemory<std::complex<double>> *a, int lda) {
1126 return DoBlasInternal(
1127 wrap::rocblas_zher2, stream, /* pointer_mode_host = */ true,
1128 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1129 complex_cast(y), incy, complex_cast(a), lda);
1130 }
1131
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1132 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1133 std::complex<float> alpha,
1134 const DeviceMemory<std::complex<float>> &ap,
1135 const DeviceMemory<std::complex<float>> &x, int incx,
1136 std::complex<float> beta,
1137 DeviceMemory<std::complex<float>> *y, int incy) {
1138 return DoBlasInternal(
1139 wrap::rocblas_chpmv, stream, /* pointer_mode_host = */ true,
1140 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap),
1141 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1142 }
1143
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1144 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1145 std::complex<double> alpha,
1146 const DeviceMemory<std::complex<double>> &ap,
1147 const DeviceMemory<std::complex<double>> &x, int incx,
1148 std::complex<double> beta,
1149 DeviceMemory<std::complex<double>> *y, int incy) {
1150 return DoBlasInternal(
1151 wrap::rocblas_zhpmv, stream, /* pointer_mode_host = */ true,
1152 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap),
1153 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1154 }
1155
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)1156 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1157 float alpha,
1158 const DeviceMemory<std::complex<float>> &x, int incx,
1159 DeviceMemory<std::complex<float>> *ap) {
1160 return DoBlasInternal(wrap::rocblas_chpr, stream,
1161 /* pointer_mode_host = */ true,
1162 ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1163 complex_cast(x), incx, complex_cast(ap));
1164 }
1165
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)1166 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1167 double alpha,
1168 const DeviceMemory<std::complex<double>> &x, int incx,
1169 DeviceMemory<std::complex<double>> *ap) {
1170 return DoBlasInternal(wrap::rocblas_zhpr, stream,
1171 /* pointer_mode_host = */ true,
1172 ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1173 complex_cast(x), incx, complex_cast(ap));
1174 }
1175
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)1176 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1177 std::complex<float> alpha,
1178 const DeviceMemory<std::complex<float>> &x, int incx,
1179 const DeviceMemory<std::complex<float>> &y, int incy,
1180 DeviceMemory<std::complex<float>> *ap) {
1181 return DoBlasInternal(
1182 wrap::rocblas_chpr2, stream, /* pointer_mode_host = */ true,
1183 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1184 complex_cast(y), incy, complex_cast(ap));
1185 }
1186
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)1187 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1188 std::complex<double> alpha,
1189 const DeviceMemory<std::complex<double>> &x, int incx,
1190 const DeviceMemory<std::complex<double>> &y, int incy,
1191 DeviceMemory<std::complex<double>> *ap) {
1192 return DoBlasInternal(
1193 wrap::rocblas_zhpr2, stream, /* pointer_mode_host = */ true,
1194 ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1195 complex_cast(y), incy, complex_cast(ap));
1196 }
1197
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1198 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1199 uint64 k, float alpha, const DeviceMemory<float> &a,
1200 int lda, const DeviceMemory<float> &x, int incx,
1201 float beta, DeviceMemory<float> *y, int incy) {
1202 return DoBlasInternal(
1203 wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true,
1204 ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
1205 incx, &beta, GpuMemoryMutable(y), incy);
1206 }
1207
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1208 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1209 uint64 k, double alpha, const DeviceMemory<double> &a,
1210 int lda, const DeviceMemory<double> &x, int incx,
1211 double beta, DeviceMemory<double> *y, int incy) {
1212 return DoBlasInternal(
1213 wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true,
1214 ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
1215 incx, &beta, GpuMemoryMutable(y), incy);
1216 }
1217
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1218 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1219 float alpha, const DeviceMemory<float> &ap,
1220 const DeviceMemory<float> &x, int incx, float beta,
1221 DeviceMemory<float> *y, int incy) {
1222 return DoBlasInternal(wrap::rocblas_sspmv, stream,
1223 /* pointer_mode_host = */ true,
1224 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1225 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1226 }
1227
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1228 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1229 double alpha, const DeviceMemory<double> &ap,
1230 const DeviceMemory<double> &x, int incx, double beta,
1231 DeviceMemory<double> *y, int incy) {
1232 return DoBlasInternal(wrap::rocblas_dspmv, stream,
1233 /* pointer_mode_host = */ true,
1234 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1235 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1236 }
1237
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)1238 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1239 float alpha, const DeviceMemory<float> &x, int incx,
1240 DeviceMemory<float> *ap) {
1241 return DoBlasInternal(wrap::rocblas_sspr, stream,
1242 /* pointer_mode_host = */ true,
1243 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1244 GpuMemoryMutable(ap));
1245 }
1246
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)1247 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1248 double alpha, const DeviceMemory<double> &x, int incx,
1249 DeviceMemory<double> *ap) {
1250 return DoBlasInternal(wrap::rocblas_dspr, stream,
1251 /* pointer_mode_host = */ true,
1252 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1253 GpuMemoryMutable(ap));
1254 }
1255
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)1256 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1257 float alpha, const DeviceMemory<float> &x, int incx,
1258 const DeviceMemory<float> &y, int incy,
1259 DeviceMemory<float> *ap) {
1260 return DoBlasInternal(wrap::rocblas_sspr2, stream,
1261 /* pointer_mode_host = */ true,
1262 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1263 GpuMemory(y), incy, GpuMemoryMutable(ap));
1264 }
1265
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)1266 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1267 double alpha, const DeviceMemory<double> &x, int incx,
1268 const DeviceMemory<double> &y, int incy,
1269 DeviceMemory<double> *ap) {
1270 return DoBlasInternal(wrap::rocblas_dspr2, stream,
1271 /* pointer_mode_host = */ true,
1272 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1273 GpuMemory(y), incy, GpuMemoryMutable(ap));
1274 }
1275
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1276 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1277 float alpha, const DeviceMemory<float> &a, int lda,
1278 const DeviceMemory<float> &x, int incx, float beta,
1279 DeviceMemory<float> *y, int incy) {
1280 return DoBlasInternal(wrap::rocblas_ssymv, stream,
1281 /* pointer_mode_host = */ true,
1282 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1283 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1284 }
1285
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1286 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1287 double alpha, const DeviceMemory<double> &a, int lda,
1288 const DeviceMemory<double> &x, int incx, double beta,
1289 DeviceMemory<double> *y, int incy) {
1290 return DoBlasInternal(wrap::rocblas_dsymv, stream,
1291 /* pointer_mode_host = */ true,
1292 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1293 GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1294 }
1295
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)1296 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1297 float alpha, const DeviceMemory<float> &x, int incx,
1298 DeviceMemory<float> *a, int lda) {
1299 return DoBlasInternal(wrap::rocblas_ssyr, stream,
1300 /* pointer_mode_host = */ true,
1301 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1302 GpuMemoryMutable(a), lda);
1303 }
1304
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)1305 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1306 double alpha, const DeviceMemory<double> &x, int incx,
1307 DeviceMemory<double> *a, int lda) {
1308 return DoBlasInternal(wrap::rocblas_dsyr, stream,
1309 /* pointer_mode_host = */ true,
1310 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1311 GpuMemoryMutable(a), lda);
1312 }
1313
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)1314 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1315 float alpha, const DeviceMemory<float> &x, int incx,
1316 const DeviceMemory<float> &y, int incy,
1317 DeviceMemory<float> *a, int lda) {
1318 return DoBlasInternal(wrap::rocblas_ssyr2, stream,
1319 /* pointer_mode_host = */ true,
1320 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1321 GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1322 }
1323
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)1324 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1325 double alpha, const DeviceMemory<double> &x, int incx,
1326 const DeviceMemory<double> &y, int incy,
1327 DeviceMemory<double> *a, int lda) {
1328 return DoBlasInternal(wrap::rocblas_dsyr2, stream,
1329 /* pointer_mode_host = */ true,
1330 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1331 GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1332 }
1333
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1334 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1335 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1336 uint64 k, const DeviceMemory<float> &a, int lda,
1337 DeviceMemory<float> *x, int incx) {
1338 return DoBlasInternal(wrap::rocblas_stbmv, stream,
1339 /* pointer_mode_host = */ false,
1340 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1341 ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1342 GpuMemoryMutable(x), incx);
1343 }
1344
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1345 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1346 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1347 uint64 k, const DeviceMemory<double> &a, int lda,
1348 DeviceMemory<double> *x, int incx) {
1349 return DoBlasInternal(wrap::rocblas_dtbmv, stream,
1350 /* pointer_mode_host = */ false,
1351 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1352 ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1353 GpuMemoryMutable(x), incx);
1354 }
1355
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1356 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1357 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1358 uint64 k, const DeviceMemory<std::complex<float>> &a,
1359 int lda, DeviceMemory<std::complex<float>> *x,
1360 int incx) {
1361 return DoBlasInternal(wrap::rocblas_ctbmv, stream,
1362 /* pointer_mode_host = */ false,
1363 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1364 ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1365 complex_cast(x), incx);
1366 }
1367
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1368 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1369 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1370 uint64 k, const DeviceMemory<std::complex<double>> &a,
1371 int lda, DeviceMemory<std::complex<double>> *x,
1372 int incx) {
1373 return DoBlasInternal(wrap::rocblas_ztbmv, stream,
1374 /* pointer_mode_host = */ false,
1375 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1376 ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1377 complex_cast(x), incx);
1378 }
1379
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1380 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1381 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1382 uint64 k, const DeviceMemory<float> &a, int lda,
1383 DeviceMemory<float> *x, int incx) {
1384 return DoBlasInternal(wrap::rocblas_stbsv, stream,
1385 /* pointer_mode_host = */ false,
1386 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1387 ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1388 GpuMemoryMutable(x), incx);
1389 }
1390
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1391 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1392 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1393 uint64 k, const DeviceMemory<double> &a, int lda,
1394 DeviceMemory<double> *x, int incx) {
1395 return DoBlasInternal(wrap::rocblas_dtbsv, stream,
1396 /* pointer_mode_host = */ false,
1397 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1398 ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1399 GpuMemoryMutable(x), incx);
1400 }
1401
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1402 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1403 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1404 uint64 k, const DeviceMemory<std::complex<float>> &a,
1405 int lda, DeviceMemory<std::complex<float>> *x,
1406 int incx) {
1407 return DoBlasInternal(wrap::rocblas_ctbsv, stream,
1408 /* pointer_mode_host = */ false,
1409 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1410 ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1411 complex_cast(x), incx);
1412 }
1413
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1414 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1415 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1416 uint64 k, const DeviceMemory<std::complex<double>> &a,
1417 int lda, DeviceMemory<std::complex<double>> *x,
1418 int incx) {
1419 return DoBlasInternal(wrap::rocblas_ztbsv, stream,
1420 /* pointer_mode_host = */ false,
1421 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1422 ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1423 complex_cast(x), incx);
1424 }
1425
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1426 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1427 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1428 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1429 int incx) {
1430 return DoBlasInternal(
1431 wrap::rocblas_stpmv, stream, /* pointer_mode_host = */ false,
1432 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1433 ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1434 }
1435
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1436 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1437 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1438 const DeviceMemory<double> &ap,
1439 DeviceMemory<double> *x, int incx) {
1440 return DoBlasInternal(
1441 wrap::rocblas_dtpmv, stream, /* pointer_mode_host = */ false,
1442 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1443 ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1444 }
1445
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1446 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1447 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1448 const DeviceMemory<std::complex<float>> &ap,
1449 DeviceMemory<std::complex<float>> *x, int incx) {
1450 return DoBlasInternal(
1451 wrap::rocblas_ctpmv, stream, /* pointer_mode_host = */ false,
1452 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1453 ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1454 }
1455
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1456 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1457 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1458 const DeviceMemory<std::complex<double>> &ap,
1459 DeviceMemory<std::complex<double>> *x, int incx) {
1460 return DoBlasInternal(
1461 wrap::rocblas_ztpmv, stream, /* pointer_mode_host = */ false,
1462 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1463 ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1464 }
1465
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1466 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1467 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1468 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1469 int incx) {
1470 return DoBlasInternal(
1471 wrap::rocblas_stpsv, stream, /* pointer_mode_host = */ false,
1472 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1473 ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1474 }
1475
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1476 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1477 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1478 const DeviceMemory<double> &ap,
1479 DeviceMemory<double> *x, int incx) {
1480 return DoBlasInternal(
1481 wrap::rocblas_dtpsv, stream, /* pointer_mode_host = */ false,
1482 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1483 ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1484 }
1485
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1486 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1487 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1488 const DeviceMemory<std::complex<float>> &ap,
1489 DeviceMemory<std::complex<float>> *x, int incx) {
1490 return DoBlasInternal(
1491 wrap::rocblas_ctpsv, stream, /* pointer_mode_host = */ false,
1492 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1493 ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1494 }
1495
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1496 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1497 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1498 const DeviceMemory<std::complex<double>> &ap,
1499 DeviceMemory<std::complex<double>> *x, int incx) {
1500 return DoBlasInternal(
1501 wrap::rocblas_ztpsv, stream, /* pointer_mode_host = */ false,
1502 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1503 ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1504 }
1505
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1506 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1507 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1508 const DeviceMemory<float> &a, int lda,
1509 DeviceMemory<float> *x, int incx) {
1510 return DoBlasInternal(
1511 wrap::rocblas_strmv, stream, /* pointer_mode_host = */ false,
1512 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1513 ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1514 }
1515
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1516 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1517 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1518 const DeviceMemory<double> &a, int lda,
1519 DeviceMemory<double> *x, int incx) {
1520 return DoBlasInternal(
1521 wrap::rocblas_dtrmv, stream, /* pointer_mode_host = */ false,
1522 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1523 ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1524 }
1525
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1526 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1527 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1528 const DeviceMemory<std::complex<float>> &a, int lda,
1529 DeviceMemory<std::complex<float>> *x, int incx) {
1530 return DoBlasInternal(
1531 wrap::rocblas_ctrmv, stream, /* pointer_mode_host = */ false,
1532 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1533 ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1534 }
1535
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1536 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1537 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1538 const DeviceMemory<std::complex<double>> &a, int lda,
1539 DeviceMemory<std::complex<double>> *x, int incx) {
1540 return DoBlasInternal(
1541 wrap::rocblas_ztrmv, stream, /* pointer_mode_host = */ false,
1542 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1543 ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1544 }
1545
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1546 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1547 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1548 const DeviceMemory<float> &a, int lda,
1549 DeviceMemory<float> *x, int incx) {
1550 return DoBlasInternal(
1551 wrap::rocblas_strsv, stream, /* pointer_mode_host = */ false,
1552 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1553 ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1554 }
1555
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1556 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1557 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1558 const DeviceMemory<double> &a, int lda,
1559 DeviceMemory<double> *x, int incx) {
1560 return DoBlasInternal(
1561 wrap::rocblas_dtrsv, stream, /* pointer_mode_host = */ false,
1562 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1563 ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1564 }
1565
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1566 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1567 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1568 const DeviceMemory<std::complex<float>> &a, int lda,
1569 DeviceMemory<std::complex<float>> *x, int incx) {
1570 return DoBlasInternal(
1571 wrap::rocblas_ctrsv, stream, /* pointer_mode_host = */ false,
1572 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1573 ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1574 }
1575
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1576 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1577 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1578 const DeviceMemory<std::complex<double>> &a, int lda,
1579 DeviceMemory<std::complex<double>> *x, int incx) {
1580 return DoBlasInternal(
1581 wrap::rocblas_ztrsv, stream, /* pointer_mode_host = */ false,
1582 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1583 ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1584 }
1585
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)1586 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1587 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1588 float alpha, const DeviceMemory<Eigen::half> &a,
1589 int lda, const DeviceMemory<Eigen::half> &b, int ldb,
1590 float beta, DeviceMemory<Eigen::half> *c, int ldc) {
1591 blas_log("DoBlasGemm");
1592 VLOG(1) << absl::StreamFormat(
1593 "doing rocBLAS SGEMM<half>: at=%d bt=%d m=%u n=%u "
1594 "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1595 "c=%p ldc=%d",
1596 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1597 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1598 if (transa == blas::Transpose::kNoTranspose) {
1599 if (lda < static_cast<int64>(m)) {
1600 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1601 "precondition violation";
1602 }
1603 } else {
1604 if (lda < static_cast<int64>(k)) {
1605 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1606 << ") (transpose case); precondition violation";
1607 }
1608 }
1609 if (transb == blas::Transpose::kNoTranspose) {
1610 if (ldb < static_cast<int64>(k)) {
1611 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1612 << ") (no transpose case); precondition violation";
1613 }
1614 } else {
1615 if (ldb < static_cast<int64>(n)) {
1616 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1617 "precondition violation";
1618 }
1619 }
1620 const Eigen::half alpha_half(alpha);
1621 const Eigen::half beta_half(beta);
1622 return DoBlasInternal(
1623 wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
1624 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1625 reinterpret_cast<const rocblas_half *>(&alpha_half),
1626 reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
1627 reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
1628 reinterpret_cast<const rocblas_half *>(&beta_half),
1629 reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
1630 }
1631
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)1632 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1633 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1634 float alpha, const DeviceMemory<float> &a, int lda,
1635 const DeviceMemory<float> &b, int ldb, float beta,
1636 DeviceMemory<float> *c, int ldc) {
1637 blas_log("DoBlasGemm");
1638 VLOG(1) << absl::StreamFormat(
1639 "doing rocBLAS SGEMM<float>: at=%d bt=%d m=%u n=%u "
1640 "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1641 "c=%p ldc=%d",
1642 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1643 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1644 if (transa == blas::Transpose::kNoTranspose) {
1645 if (lda < static_cast<int64>(m)) {
1646 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1647 "precondition violation";
1648 }
1649 } else {
1650 if (lda < static_cast<int64>(k)) {
1651 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1652 << ") (transpose case); precondition violation";
1653 }
1654 }
1655 if (transb == blas::Transpose::kNoTranspose) {
1656 if (ldb < static_cast<int64>(k)) {
1657 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1658 << ") (no transpose case); precondition violation";
1659 }
1660 } else {
1661 if (ldb < static_cast<int64>(n)) {
1662 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1663 "precondition violation";
1664 }
1665 }
1666 return DoBlasInternal(
1667 wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true,
1668 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1669 GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1670 }
1671
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)1672 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1673 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1674 double alpha, const DeviceMemory<double> &a, int lda,
1675 const DeviceMemory<double> &b, int ldb, double beta,
1676 DeviceMemory<double> *c, int ldc) {
1677 blas_log("DoBlasGemm");
1678 return DoBlasInternal(
1679 wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true,
1680 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1681 GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1682 }
1683
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)1684 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1685 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1686 std::complex<float> alpha,
1687 const DeviceMemory<std::complex<float>> &a, int lda,
1688 const DeviceMemory<std::complex<float>> &b, int ldb,
1689 std::complex<float> beta,
1690 DeviceMemory<std::complex<float>> *c, int ldc) {
1691 blas_log("DoBlasGemm");
1692 return DoBlasInternal(
1693 wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true,
1694 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1695 complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
1696 complex_cast(beta), complex_cast(c), ldc);
1697 }
1698
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)1699 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1700 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1701 std::complex<double> alpha,
1702 const DeviceMemory<std::complex<double>> &a, int lda,
1703 const DeviceMemory<std::complex<double>> &b, int ldb,
1704 std::complex<double> beta,
1705 DeviceMemory<std::complex<double>> *c, int ldc) {
1706 blas_log("DoBlasGemm");
1707 return DoBlasInternal(
1708 wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true,
1709 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1710 complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
1711 complex_cast(beta), complex_cast(c), ldc);
1712 }
1713
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)1714 bool ROCMBlas::DoBlasGemvWithProfiling(
1715 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1716 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1717 int incx, float beta, DeviceMemory<float> *y, int incy,
1718 blas::ProfileResult *output_profile_result) {
1719 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1720 incx, beta, y, incy,
1721 output_profile_result);
1722 }
1723
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)1724 bool ROCMBlas::DoBlasGemvWithProfiling(
1725 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1726 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1727 int incx, double beta, DeviceMemory<double> *y, int incy,
1728 blas::ProfileResult *output_profile_result) {
1729 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1730 incx, beta, y, incy,
1731 output_profile_result);
1732 }
1733
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)1734 bool ROCMBlas::DoBlasGemvWithProfiling(
1735 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1736 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1737 int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1738 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1739 blas::ProfileResult *output_profile_result) {
1740 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1741 incx, beta, y, incy,
1742 output_profile_result);
1743 }
1744
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)1745 bool ROCMBlas::DoBlasGemvWithProfiling(
1746 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1747 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1748 int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1749 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1750 blas::ProfileResult *output_profile_result) {
1751 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1752 incx, beta, y, incy,
1753 output_profile_result);
1754 }
1755
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)1756 bool ROCMBlas::DoBlasGemmWithProfiling(
1757 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1758 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1759 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1760 DeviceMemory<Eigen::half> *c, int ldc,
1761 blas::ProfileResult *output_profile_result) {
1762 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1763 lda, b, ldb, beta, c, ldc,
1764 output_profile_result);
1765 }
1766
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)1767 bool ROCMBlas::DoBlasGemmWithProfiling(
1768 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1769 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1770 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1771 int ldc, blas::ProfileResult *output_profile_result) {
1772 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1773 lda, b, ldb, beta, c, ldc,
1774 output_profile_result);
1775 }
1776
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)1777 bool ROCMBlas::DoBlasGemmWithProfiling(
1778 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1779 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1780 const DeviceMemory<double> &b, int ldb, double beta,
1781 DeviceMemory<double> *c, int ldc,
1782 blas::ProfileResult *output_profile_result) {
1783 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1784 lda, b, ldb, beta, c, ldc,
1785 output_profile_result);
1786 }
1787
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)1788 bool ROCMBlas::DoBlasGemmWithProfiling(
1789 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1790 uint64 n, uint64 k, std::complex<float> alpha,
1791 const DeviceMemory<std::complex<float>> &a, int lda,
1792 const DeviceMemory<std::complex<float>> &b, int ldb,
1793 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1794 blas::ProfileResult *output_profile_result) {
1795 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1796 lda, b, ldb, beta, c, ldc,
1797 output_profile_result);
1798 }
1799
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)1800 bool ROCMBlas::DoBlasGemmWithProfiling(
1801 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1802 uint64 n, uint64 k, std::complex<double> alpha,
1803 const DeviceMemory<std::complex<double>> &a, int lda,
1804 const DeviceMemory<std::complex<double>> &b, int ldb,
1805 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1806 blas::ProfileResult *output_profile_result) {
1807 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1808 lda, b, ldb, beta, c, ldc,
1809 output_profile_result);
1810 }
1811
1812 template <typename T>
DoBlasGemvWithProfilingImpl(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,const T & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & x,int incx,const T & beta,DeviceMemory<T> * y,int incy,blas::ProfileResult * output_profile_result)1813 bool ROCMBlas::DoBlasGemvWithProfilingImpl(
1814 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1815 const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1816 const T &beta, DeviceMemory<T> *y, int incy,
1817 blas::ProfileResult *output_profile_result) {
1818 // ROCM TODO: properly implement the interface
1819 return false;
1820 }
1821
1822 template <typename T, typename ParamType>
DoBlasGemmWithProfilingImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const ParamType & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & b,int ldb,const ParamType & beta,DeviceMemory<T> * c,int ldc,blas::ProfileResult * output_profile_result)1823 bool ROCMBlas::DoBlasGemmWithProfilingImpl(
1824 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1825 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
1826 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
1827 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
1828 // ROCM TODO: properly implement the interface
1829 return false;
1830 }
1831
1832 template <typename InT, typename OutT, typename CompT>
DoBlasGemmWithAlgorithmImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const CompT & alpha,const DeviceMemory<InT> & a,int lda,const DeviceMemory<InT> & b,int ldb,const CompT & beta,DeviceMemory<OutT> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1833 bool ROCMBlas::DoBlasGemmWithAlgorithmImpl(
1834 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1835 uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
1836 const DeviceMemory<InT> &b, int ldb, const CompT &beta,
1837 DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
1838 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1839 // ROCM TODO: properly implement the interface
1840 return false;
1841 }
1842
GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> * out_algorithms)1843 bool ROCMBlas::GetBlasGemmAlgorithms(
1844 std::vector<blas::AlgorithmType> *out_algorithms) {
1845 // ROCM TODO: properly implement the interface
1846 return true;
1847 }
1848
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int32> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1849 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1850 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1851 uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
1852 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
1853 const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
1854 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1855 blas::ProfileResult *output_profile_result) {
1856 LOG(ERROR)
1857 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1858 << "for the \"int8\" datatype";
1859 return false;
1860 }
1861
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1862 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1863 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1864 uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1865 const DeviceMemory<Eigen::half> &a, int lda,
1866 const DeviceMemory<Eigen::half> &b, int ldb,
1867 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1868 int ldc, blas::ComputationType computation_type,
1869 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1870 LOG(ERROR)
1871 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1872 << "for the \"half\" datatype";
1873 return false;
1874 }
1875
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1876 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1877 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1878 uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
1879 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1880 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1881 int ldc, blas::ComputationType computation_type,
1882 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1883 LOG(ERROR)
1884 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1885 << "for the \"float\" datatype";
1886 return false;
1887 }
1888
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1889 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1890 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1891 uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
1892 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1893 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1894 int ldc, blas::ComputationType computation_type,
1895 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1896 LOG(ERROR)
1897 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1898 << "for the \"double\" datatype";
1899 return false;
1900 }
1901
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1902 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1903 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1904 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1905 const DeviceMemory<std::complex<float>> &a, int lda,
1906 const DeviceMemory<std::complex<float>> &b, int ldb,
1907 const HostOrDeviceScalar<std::complex<float>> &beta,
1908 DeviceMemory<std::complex<float>> *c, int ldc,
1909 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1910 blas::ProfileResult *output_profile_result) {
1911 LOG(ERROR)
1912 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1913 << "for the \"complex<float>\" datatype";
1914 return false;
1915 }
1916
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1917 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1918 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1919 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1920 const DeviceMemory<std::complex<double>> &a, int lda,
1921 const DeviceMemory<std::complex<double>> &b, int ldb,
1922 const HostOrDeviceScalar<std::complex<double>> &beta,
1923 DeviceMemory<std::complex<double>> *c, int ldc,
1924 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1925 blas::ProfileResult *output_profile_result) {
1926 LOG(ERROR)
1927 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1928 << "for the \"complex<double>\" datatype";
1929 return false;
1930 }
1931
1932 // This copies from source memory: raw_ptrs[i] to target memory:
1933 // device_memory_ptr at the interval of matrix_byte_size, or vice versa.
1934 // The below algorithm tries to minimize the number of memcpy by consolidating
1935 // neighboring memcpy into a single request
1936 template <typename MAPPED_T>
ReorganizeMemory(Stream * stream,DeviceMemory<MAPPED_T> * device_memory,const std::vector<MAPPED_T * > & raw_ptrs,int batch_count,uint64_t batch_stride,bool gather)1937 port::Status ReorganizeMemory(Stream *stream,
1938 DeviceMemory<MAPPED_T> *device_memory,
1939 const std::vector<MAPPED_T *> &raw_ptrs,
1940 int batch_count, uint64_t batch_stride,
1941 bool gather) {
1942 assert(batch_count > 0);
1943 char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
1944 char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
1945 char *dst_ptr = device_memory_ptr;
1946 size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
1947 uint64_t cur_stride_size = matrix_byte_size;
1948
1949 for (int i = 1; i < batch_count; ++i) {
1950 if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
1951 cur_stride_size += matrix_byte_size;
1952 } else {
1953 DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
1954 DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
1955 bool a_status =
1956 gather
1957 ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
1958 : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
1959 if (!a_status) {
1960 return port::Status(
1961 port::error::INTERNAL,
1962 "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
1963 }
1964 src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
1965 dst_ptr = device_memory_ptr + i * matrix_byte_size;
1966 cur_stride_size = matrix_byte_size;
1967 }
1968 }
1969
1970 DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
1971 DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
1972 bool a_status =
1973 gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
1974 : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
1975 if (!a_status)
1976 return port::Status(
1977 port::error::INTERNAL,
1978 "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
1979 return port::Status::OK();
1980 }
1981
1982 template <typename T>
AllocateStridedBuffer(const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type * > & raw_ptrs,int batch_count,uint64_t batch_stride,ScratchAllocator * scratch_allocator,Stream * stream,std::unique_ptr<TemporaryDeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>> * temp_memory,DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> * device_memory,bool copy_data,bool & reallocated)1983 port::Status ROCMBlas::AllocateStridedBuffer(
1984 const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
1985 &raw_ptrs,
1986 int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
1987 Stream *stream,
1988 std::unique_ptr<TemporaryDeviceMemory<
1989 typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
1990 DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
1991 *device_memory,
1992 bool copy_data, bool &reallocated) {
1993 assert(device_memory != nullptr);
1994
1995 using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
1996
1997 bool needs_allocate_strided = false;
1998 for (int i = 1; i < batch_count; ++i) {
1999 uint64_t tmp_batch_stride = raw_ptrs[i] - raw_ptrs[i - 1];
2000 if (tmp_batch_stride != batch_stride) {
2001 needs_allocate_strided = true;
2002 break;
2003 }
2004 }
2005
2006 size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
2007 size_t matrix_batch_byte_size = matrix_byte_size * batch_count;
2008
2009 // No need to do re-allocation, take the short cut and return
2010 if (!needs_allocate_strided) {
2011 *device_memory = DeviceMemory<MAPPED_T>(
2012 DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
2013 reallocated = false;
2014 return port::Status::OK();
2015 }
2016
2017 if (scratch_allocator != nullptr) {
2018 SE_ASSIGN_OR_RETURN(
2019 DeviceMemory<uint8> batch_matrix_bytes,
2020 scratch_allocator->AllocateBytes(matrix_batch_byte_size));
2021 *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes);
2022 } else {
2023 assert(temp_memory != nullptr);
2024 SE_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray<MAPPED_T>(
2025 matrix_batch_byte_size));
2026 *device_memory =
2027 DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
2028 }
2029
2030 reallocated = true;
2031
2032 if (copy_data)
2033 return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
2034 batch_stride, true);
2035 return port::Status::OK();
2036 }
2037
2038 template <typename T, typename FuncT>
DoBlasGemmBatchedInternal(FuncT rocblas_func,Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,T alpha,const port::ArraySlice<DeviceMemory<T> * > & a_ptrs_to_wrappers,int lda,const port::ArraySlice<DeviceMemory<T> * > & b_ptrs_to_wrappers,int ldb,T beta,const port::ArraySlice<DeviceMemory<T> * > & c_ptrs_to_wrappers,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2039 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
2040 FuncT rocblas_func, Stream *stream, blas::Transpose transa,
2041 blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
2042 const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
2043 const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
2044 T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
2045 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2046 using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
2047
2048 // Sanity checks before making any further progress
2049 uint64_t batch_stride_a = 0;
2050 uint64_t batch_stride_b = 0;
2051 uint64_t batch_stride_c = 0;
2052
2053 assert(ldc >= m);
2054 batch_stride_c = ldc * n;
2055
2056 if (ROCMBlasTranspose(transa) == rocblas_operation_none) {
2057 assert(lda >= m);
2058 batch_stride_a = lda * k;
2059 } else {
2060 assert(lda >= k);
2061 batch_stride_a = lda * m;
2062 }
2063
2064 if (ROCMBlasTranspose(transb) == rocblas_operation_none) {
2065 assert(ldb >= k);
2066 batch_stride_b = ldb * n;
2067 } else {
2068 assert(ldb >= n);
2069 batch_stride_b = ldb * k;
2070 }
2071
2072 // Allocate local vectors to hold device pointers to matrices
2073 std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
2074 for (int i = 0; i < batch_count; ++i) {
2075 // static_cast does work when converting Eigen::half* to rocblas_half*,
2076 // hence the use of reinterpret_cast
2077 a_raw_ptrs.push_back(
2078 reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
2079 b_raw_ptrs.push_back(
2080 reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
2081 c_raw_ptrs.push_back(
2082 reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
2083 }
2084
2085 DeviceMemory<MAPPED_T> a;
2086 // Make sure the temporary memory are in-scope before the function returns
2087 std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
2088 bool reallocated_a, reallocated_b, reallocated_c;
2089 port::Status a_allocation_status = AllocateStridedBuffer<T>(
2090 a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream,
2091 &a_temp, &a, true, reallocated_a);
2092 if (a_allocation_status != port::Status::OK()) {
2093 return a_allocation_status;
2094 }
2095
2096 DeviceMemory<MAPPED_T> b;
2097 std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
2098 port::Status b_allocation_status = AllocateStridedBuffer<T>(
2099 b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream,
2100 &b_temp, &b, true, reallocated_b);
2101 if (b_allocation_status != port::Status::OK()) {
2102 return b_allocation_status;
2103 }
2104
2105 DeviceMemory<MAPPED_T> c;
2106 std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
2107 port::Status c_allocation_status = AllocateStridedBuffer<T>(
2108 c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream,
2109 &c_temp, &c, true, reallocated_c); // can disable copy if beta=0
2110 if (c_allocation_status != port::Status::OK()) {
2111 return c_allocation_status;
2112 }
2113
2114 MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
2115 MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
2116
2117 bool ok;
2118 ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true,
2119 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2120 n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
2121 batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
2122 GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
2123 batch_stride_c, batch_count);
2124 if (!ok)
2125 return port::Status(port::error::INTERNAL,
2126 "failed BLAS call, see log for details");
2127 if (reallocated_c)
2128 return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c,
2129 false);
2130 return port::Status::OK();
2131 }
2132
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2133 bool ROCMBlas::DoBlasGemmBatched(
2134 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2135 uint64 n, uint64 k, float alpha,
2136 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
2137 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
2138 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
2139 int batch_count, ScratchAllocator *scratch_allocator) {
2140 blas_log("DoBlasGemmBatched");
2141 const Eigen::half alpha_half(alpha);
2142 const Eigen::half beta_half(beta);
2143
2144 port::Status status = DoBlasGemmBatchedInternal(
2145 wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
2146 alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
2147 scratch_allocator);
2148 if (!status.ok()) {
2149 LOG(ERROR) << status;
2150 }
2151
2152 return status.ok();
2153 }
2154
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<float> * > & b_array,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2155 bool ROCMBlas::DoBlasGemmBatched(
2156 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2157 uint64 n, uint64 k, float alpha,
2158 const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
2159 const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
2160 const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
2161 int batch_count, ScratchAllocator *scratch_allocator) {
2162 blas_log("DoBlasGemmBatched");
2163 port::Status status = DoBlasGemmBatchedInternal(
2164 wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
2165 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2166 scratch_allocator);
2167 if (!status.ok()) {
2168 LOG(ERROR) << status;
2169 }
2170 return status.ok();
2171 }
2172
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<double> * > & b_array,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2173 bool ROCMBlas::DoBlasGemmBatched(
2174 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2175 uint64 n, uint64 k, double alpha,
2176 const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
2177 const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
2178 double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
2179 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2180 blas_log("DoBlasGemmBatched");
2181 port::Status status = DoBlasGemmBatchedInternal(
2182 wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
2183 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2184 scratch_allocator);
2185 if (!status.ok()) {
2186 LOG(ERROR) << status;
2187 }
2188 return status.ok();
2189 }
2190
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b_array,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2191 bool ROCMBlas::DoBlasGemmBatched(
2192 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2193 uint64 n, uint64 k, std::complex<float> alpha,
2194 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
2195 int lda,
2196 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
2197 int ldb, std::complex<float> beta,
2198 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
2199 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2200 blas_log("DoBlasGemmBatched");
2201 port::Status status = DoBlasGemmBatchedInternal(
2202 wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
2203 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2204 scratch_allocator);
2205 if (!status.ok()) {
2206 LOG(ERROR) << status;
2207 }
2208 return status.ok();
2209 }
2210
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b_array,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2211 bool ROCMBlas::DoBlasGemmBatched(
2212 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2213 uint64 n, uint64 k, std::complex<double> alpha,
2214 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
2215 int lda,
2216 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
2217 int ldb, std::complex<double> beta,
2218 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
2219 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2220 blas_log("DoBlasGemmBatched");
2221 port::Status status = DoBlasGemmBatchedInternal(
2222 wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
2223 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2224 scratch_allocator);
2225 if (!status.ok()) {
2226 LOG(ERROR) << status;
2227 }
2228 return status.ok();
2229 }
2230
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2231 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
2232 blas::UpperLower uplo, uint64 m, uint64 n,
2233 std::complex<float> alpha,
2234 const DeviceMemory<std::complex<float>> &a, int lda,
2235 const DeviceMemory<std::complex<float>> &b, int ldb,
2236 std::complex<float> beta,
2237 DeviceMemory<std::complex<float>> *c, int ldc) {
2238 return DoBlasInternal(wrap::rocblas_chemm, stream,
2239 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2240 ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2241 complex_cast(a), lda, complex_cast(b), ldb,
2242 complex_cast(beta), complex_cast(c), ldc);
2243 }
2244
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2245 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
2246 blas::UpperLower uplo, uint64 m, uint64 n,
2247 std::complex<double> alpha,
2248 const DeviceMemory<std::complex<double>> &a, int lda,
2249 const DeviceMemory<std::complex<double>> &b, int ldb,
2250 std::complex<double> beta,
2251 DeviceMemory<std::complex<double>> *c, int ldc) {
2252 return DoBlasInternal(wrap::rocblas_zhemm, stream,
2253 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2254 ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2255 complex_cast(a), lda, complex_cast(b), ldb,
2256 complex_cast(beta), complex_cast(c), ldc);
2257 }
2258
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2259 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2260 blas::Transpose trans, uint64 n, uint64 k,
2261 float alpha,
2262 const DeviceMemory<std::complex<float>> &a, int lda,
2263 float beta, DeviceMemory<std::complex<float>> *c,
2264 int ldc) {
2265 return DoBlasInternal(wrap::rocblas_cherk, stream,
2266 /* pointer_mode_host = */ true,
2267 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2268 k, complex_cast(alpha), complex_cast(a), lda,
2269 complex_cast(beta), complex_cast(c), ldc);
2270 }
2271
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2272 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2273 blas::Transpose trans, uint64 n, uint64 k,
2274 double alpha,
2275 const DeviceMemory<std::complex<double>> &a, int lda,
2276 double beta, DeviceMemory<std::complex<double>> *c,
2277 int ldc) {
2278 return DoBlasInternal(wrap::rocblas_zherk, stream,
2279 /* pointer_mode_host = */ true,
2280 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2281 k, complex_cast(alpha), complex_cast(a), lda,
2282 complex_cast(beta), complex_cast(c), ldc);
2283 }
2284
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2285 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2286 blas::Transpose trans, uint64 n, uint64 k,
2287 std::complex<float> alpha,
2288 const DeviceMemory<std::complex<float>> &a, int lda,
2289 const DeviceMemory<std::complex<float>> &b, int ldb,
2290 float beta, DeviceMemory<std::complex<float>> *c,
2291 int ldc) {
2292 return DoBlasInternal(
2293 wrap::rocblas_cher2k, stream, /* pointer_mode_host = */ true,
2294 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2295 complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2296 complex_cast(beta), complex_cast(c), ldc);
2297 }
2298
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2299 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2300 blas::Transpose trans, uint64 n, uint64 k,
2301 std::complex<double> alpha,
2302 const DeviceMemory<std::complex<double>> &a, int lda,
2303 const DeviceMemory<std::complex<double>> &b, int ldb,
2304 double beta, DeviceMemory<std::complex<double>> *c,
2305 int ldc) {
2306 return DoBlasInternal(
2307 wrap::rocblas_zher2k, stream, /* pointer_mode_host = */ true,
2308 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2309 complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2310 complex_cast(beta), complex_cast(c), ldc);
2311 }
2312
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2313 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2314 blas::UpperLower uplo, uint64 m, uint64 n,
2315 float alpha, const DeviceMemory<float> &a, int lda,
2316 const DeviceMemory<float> &b, int ldb, float beta,
2317 DeviceMemory<float> *c, int ldc) {
2318 return DoBlasInternal(
2319 wrap::rocblas_ssymm, stream, /* pointer_mode_host = */ true,
2320 ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a),
2321 lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2322 }
2323
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2324 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2325 blas::UpperLower uplo, uint64 m, uint64 n,
2326 double alpha, const DeviceMemory<double> &a, int lda,
2327 const DeviceMemory<double> &b, int ldb, double beta,
2328 DeviceMemory<double> *c, int ldc) {
2329 return DoBlasInternal(
2330 wrap::rocblas_dsymm, stream, /* pointer_mode_host = */ true,
2331 ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a),
2332 lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2333 }
2334
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2335 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2336 blas::UpperLower uplo, uint64 m, uint64 n,
2337 std::complex<float> alpha,
2338 const DeviceMemory<std::complex<float>> &a, int lda,
2339 const DeviceMemory<std::complex<float>> &b, int ldb,
2340 std::complex<float> beta,
2341 DeviceMemory<std::complex<float>> *c, int ldc) {
2342 return DoBlasInternal(wrap::rocblas_csymm, stream,
2343 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2344 ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2345 complex_cast(a), lda, complex_cast(b), ldb,
2346 complex_cast(beta), complex_cast(c), ldc);
2347 }
2348
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2349 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2350 blas::UpperLower uplo, uint64 m, uint64 n,
2351 std::complex<double> alpha,
2352 const DeviceMemory<std::complex<double>> &a, int lda,
2353 const DeviceMemory<std::complex<double>> &b, int ldb,
2354 std::complex<double> beta,
2355 DeviceMemory<std::complex<double>> *c, int ldc) {
2356 return DoBlasInternal(wrap::rocblas_zsymm, stream,
2357 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2358 ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2359 complex_cast(a), lda, complex_cast(b), ldb,
2360 complex_cast(beta), complex_cast(c), ldc);
2361 }
2362
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)2363 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2364 blas::Transpose trans, uint64 n, uint64 k,
2365 float alpha, const DeviceMemory<float> &a, int lda,
2366 float beta, DeviceMemory<float> *c, int ldc) {
2367 return DoBlasInternal(
2368 wrap::rocblas_ssyrk, stream, /* pointer_mode_host = */ true,
2369 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2370 GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc);
2371 }
2372
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)2373 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2374 blas::Transpose trans, uint64 n, uint64 k,
2375 double alpha, const DeviceMemory<double> &a, int lda,
2376 double beta, DeviceMemory<double> *c, int ldc) {
2377 return DoBlasInternal(
2378 wrap::rocblas_dsyrk, stream, /* pointer_mode_host = */ true,
2379 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2380 GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc);
2381 }
2382
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2383 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2384 blas::Transpose trans, uint64 n, uint64 k,
2385 std::complex<float> alpha,
2386 const DeviceMemory<std::complex<float>> &a, int lda,
2387 std::complex<float> beta,
2388 DeviceMemory<std::complex<float>> *c, int ldc) {
2389 return DoBlasInternal(wrap::rocblas_csyrk, stream,
2390 /* pointer_mode_host = */ true,
2391 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2392 k, complex_cast(alpha), complex_cast(a), lda,
2393 complex_cast(beta), complex_cast(c), ldc);
2394 }
2395
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2396 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2397 blas::Transpose trans, uint64 n, uint64 k,
2398 std::complex<double> alpha,
2399 const DeviceMemory<std::complex<double>> &a, int lda,
2400 std::complex<double> beta,
2401 DeviceMemory<std::complex<double>> *c, int ldc) {
2402 return DoBlasInternal(wrap::rocblas_zsyrk, stream,
2403 /* pointer_mode_host = */ true,
2404 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2405 k, complex_cast(alpha), complex_cast(a), lda,
2406 complex_cast(beta), complex_cast(c), ldc);
2407 }
2408
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2409 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2410 blas::Transpose trans, uint64 n, uint64 k,
2411 float alpha, const DeviceMemory<float> &a, int lda,
2412 const DeviceMemory<float> &b, int ldb, float beta,
2413 DeviceMemory<float> *c, int ldc) {
2414 return DoBlasInternal(
2415 wrap::rocblas_ssyr2k, stream, /* pointer_mode_host = */ true,
2416 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2417 GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2418 }
2419
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2420 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2421 blas::Transpose trans, uint64 n, uint64 k,
2422 double alpha, const DeviceMemory<double> &a, int lda,
2423 const DeviceMemory<double> &b, int ldb, double beta,
2424 DeviceMemory<double> *c, int ldc) {
2425 return DoBlasInternal(
2426 wrap::rocblas_dsyr2k, stream, /* pointer_mode_host = */ true,
2427 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2428 GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2429 }
2430
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2431 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2432 blas::Transpose trans, uint64 n, uint64 k,
2433 std::complex<float> alpha,
2434 const DeviceMemory<std::complex<float>> &a, int lda,
2435 const DeviceMemory<std::complex<float>> &b, int ldb,
2436 std::complex<float> beta,
2437 DeviceMemory<std::complex<float>> *c, int ldc) {
2438 return DoBlasInternal(
2439 wrap::rocblas_csyr2k, stream, /* pointer_mode_host = */ true,
2440 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2441 complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2442 complex_cast(beta), complex_cast(c), ldc);
2443 }
2444
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2445 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2446 blas::Transpose trans, uint64 n, uint64 k,
2447 std::complex<double> alpha,
2448 const DeviceMemory<std::complex<double>> &a, int lda,
2449 const DeviceMemory<std::complex<double>> &b, int ldb,
2450 std::complex<double> beta,
2451 DeviceMemory<std::complex<double>> *c, int ldc) {
2452 return DoBlasInternal(
2453 wrap::rocblas_zsyr2k, stream, /* pointer_mode_host = */ true,
2454 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2455 complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2456 complex_cast(beta), complex_cast(c), ldc);
2457 }
2458
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2459 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2460 blas::UpperLower uplo, blas::Transpose transa,
2461 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2462 const DeviceMemory<float> &a, int lda,
2463 DeviceMemory<float> *b, int ldb) {
2464 return DoBlasInternal(wrap::rocblas_strmm, stream,
2465 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2466 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2467 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2468 GpuMemoryMutable(b), ldb);
2469 }
2470
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2471 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2472 blas::UpperLower uplo, blas::Transpose transa,
2473 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2474 const DeviceMemory<double> &a, int lda,
2475 DeviceMemory<double> *b, int ldb) {
2476 return DoBlasInternal(wrap::rocblas_dtrmm, stream,
2477 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2478 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2479 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2480 GpuMemoryMutable(b), ldb);
2481 }
2482
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2483 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2484 blas::UpperLower uplo, blas::Transpose transa,
2485 blas::Diagonal diag, uint64 m, uint64 n,
2486 std::complex<float> alpha,
2487 const DeviceMemory<std::complex<float>> &a, int lda,
2488 DeviceMemory<std::complex<float>> *b, int ldb) {
2489 return DoBlasInternal(wrap::rocblas_ctrmm, stream,
2490 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2491 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2492 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2493 complex_cast(a), lda, complex_cast(b), ldb);
2494 }
2495
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2496 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2497 blas::UpperLower uplo, blas::Transpose transa,
2498 blas::Diagonal diag, uint64 m, uint64 n,
2499 std::complex<double> alpha,
2500 const DeviceMemory<std::complex<double>> &a, int lda,
2501 DeviceMemory<std::complex<double>> *b, int ldb) {
2502 return DoBlasInternal(wrap::rocblas_ztrmm, stream,
2503 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2504 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2505 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2506 complex_cast(a), lda, complex_cast(b), ldb);
2507 }
2508
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2509 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2510 blas::UpperLower uplo, blas::Transpose transa,
2511 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2512 const DeviceMemory<float> &a, int lda,
2513 DeviceMemory<float> *b, int ldb) {
2514 blas_log("DoBlasTrsm");
2515 return DoBlasInternal(wrap::rocblas_strsm, stream,
2516 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2517 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2518 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2519 GpuMemoryMutable(b), ldb);
2520 }
2521
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2522 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2523 blas::UpperLower uplo, blas::Transpose transa,
2524 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2525 const DeviceMemory<double> &a, int lda,
2526 DeviceMemory<double> *b, int ldb) {
2527 blas_log("DoBlasTrsm");
2528 return DoBlasInternal(wrap::rocblas_dtrsm, stream,
2529 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2530 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2531 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2532 GpuMemoryMutable(b), ldb);
2533 }
2534
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2535 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2536 blas::UpperLower uplo, blas::Transpose transa,
2537 blas::Diagonal diag, uint64 m, uint64 n,
2538 std::complex<float> alpha,
2539 const DeviceMemory<std::complex<float>> &a, int lda,
2540 DeviceMemory<std::complex<float>> *b, int ldb) {
2541 return DoBlasInternal(wrap::rocblas_ctrsm, stream,
2542 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2543 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2544 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2545 complex_cast(a), lda, complex_cast(b), ldb);
2546 }
2547
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2548 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2549 blas::UpperLower uplo, blas::Transpose transa,
2550 blas::Diagonal diag, uint64 m, uint64 n,
2551 std::complex<double> alpha,
2552 const DeviceMemory<std::complex<double>> &a, int lda,
2553 DeviceMemory<std::complex<double>> *b, int ldb) {
2554 return DoBlasInternal(wrap::rocblas_ztrsm, stream,
2555 /* pointer_mode_host = */ true, ROCMBlasSide(side),
2556 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2557 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2558 complex_cast(a), lda, complex_cast(b), ldb);
2559 }
2560
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)2561 bool ROCMBlas::DoBlasGemmStridedBatched(
2562 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2563 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
2564 int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
2565 int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
2566 int64 stride_c, int batch_count) {
2567 blas_log("DoBlasGemmStridedBatched");
2568 const Eigen::half alpha_half(alpha);
2569 const Eigen::half beta_half(beta);
2570
2571 return DoBlasInternal(
2572 wrap::rocblas_hgemm_strided_batched, stream,
2573 false, /* pointer_mode_host */
2574 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
2575 reinterpret_cast<const rocblas_half *>(&alpha_half),
2576 reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda, stride_a,
2577 reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb, stride_b,
2578 reinterpret_cast<const rocblas_half *>(&beta_half),
2579 reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc, stride_c,
2580 batch_count);
2581 }
2582
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)2583 bool ROCMBlas::DoBlasGemmStridedBatched(
2584 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2585 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
2586 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
2587 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
2588 int batch_count) {
2589 VLOG(1) << absl::StreamFormat(
2590 "doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
2591 "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
2592 "c=%p ldc=%d",
2593 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
2594 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
2595 return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
2596 false, /* pointer_mode_host */
2597 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2598 n, k, &alpha, GpuMemory(a), lda, stride_a, GpuMemory(b),
2599 ldb, stride_b, &beta, GpuMemoryMutable(c), ldc,
2600 stride_c, batch_count);
2601 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)2602 bool ROCMBlas::DoBlasGemmStridedBatched(
2603 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2604 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
2605 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
2606 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
2607 int batch_count) {
2608 VLOG(1) << absl::StreamFormat(
2609 "doing rocBLAS SGEMM Strided Batched<double>: at=%d bt=%d m=%u n=%u "
2610 "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
2611 "c=%p ldc=%d",
2612 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
2613 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
2614 return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
2615 false, /* pointer_mode_host */
2616 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2617 n, k, &alpha, GpuMemory(a), lda, stride_a, GpuMemory(b),
2618 ldb, stride_b, &beta, GpuMemoryMutable(c), ldc,
2619 stride_c, batch_count);
2620 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)2621 bool ROCMBlas::DoBlasGemmStridedBatched(
2622 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2623 uint64 n, uint64 k, std::complex<float> alpha,
2624 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
2625 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
2626 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
2627 int64 stride_c, int batch_count) {
2628 return DoBlasInternal(wrap::rocblas_cgemm_strided_batched, stream,
2629 false, /* pointer_mode_host */
2630 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2631 n, k, complex_cast(alpha), complex_cast(a), lda,
2632 stride_a, complex_cast(b), ldb, stride_b,
2633 complex_cast(beta), complex_cast(c), ldc, stride_c,
2634 batch_count);
2635 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)2636 bool ROCMBlas::DoBlasGemmStridedBatched(
2637 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2638 uint64 n, uint64 k, std::complex<double> alpha,
2639 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
2640 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
2641 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
2642 int64 stride_c, int batch_count) {
2643 return DoBlasInternal(wrap::rocblas_zgemm_strided_batched, stream,
2644 false, /* pointer_mode_host */
2645 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2646 n, k, complex_cast(alpha), complex_cast(a), lda,
2647 stride_a, complex_cast(b), ldb, stride_b,
2648 complex_cast(beta), complex_cast(c), ldc, stride_c,
2649 batch_count);
2650 }
2651
GetVersion(string * version)2652 port::Status ROCMBlas::GetVersion(string *version) {
2653 return port::UnimplementedError("");
2654 }
2655
2656 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams & p)2657 ROCMBlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
2658 return port::Status(
2659 port::error::UNIMPLEMENTED,
2660 "CreateBlasLtMatmulPlan is not supported with this version of ROCM");
2661 }
2662
2663 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan * plan,size_t max_workspace_size,int max_algorithm_count)2664 ROCMBlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
2665 size_t max_workspace_size,
2666 int max_algorithm_count) {
2667 return port::Status(
2668 port::error::UNIMPLEMENTED,
2669 "GetBlasLtMatmulAlgorithms is not supported with this version of ROCM");
2670 }
2671
DoBlasLtMatmul(Stream * stream,const blas::IBlasLtMatmulPlan * plan,const HostOrDeviceScalar<void> & alpha,DeviceMemoryBase a,DeviceMemoryBase b,const HostOrDeviceScalar<void> & beta,DeviceMemoryBase c,ScratchAllocator * scratch_allocator,const blas::IBlasLtMatmulAlgorithm * algorithm,DeviceMemoryBase bias,blas::ProfileResult * output_profile_result)2672 bool ROCMBlas::DoBlasLtMatmul(
2673 Stream *stream, const blas::IBlasLtMatmulPlan *plan,
2674 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
2675 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
2676 DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
2677 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
2678 blas::ProfileResult *output_profile_result) {
2679 return false;
2680 }
2681
2682 } // namespace gpu
2683
initialize_rocblas()2684 void initialize_rocblas() {
2685 auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
2686 rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2687
2688 if (!rocBlasAlreadyRegistered) {
2689 port::Status status =
2690 PluginRegistry::Instance()
2691 ->RegisterFactory<PluginRegistry::BlasFactory>(
2692 rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
2693 [](internal::StreamExecutorInterface *parent)
2694 -> blas::BlasSupport * {
2695 gpu::GpuExecutor *rocm_executor =
2696 dynamic_cast<gpu::GpuExecutor *>(parent);
2697 if (rocm_executor == nullptr) {
2698 LOG(ERROR)
2699 << "Attempting to initialize an instance of the "
2700 "rocBLAS "
2701 << "support library with a non-ROCM StreamExecutor";
2702 return nullptr;
2703 }
2704
2705 gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
2706 if (!blas->Init()) {
2707 // Note: Init() will log a more specific error.
2708 delete blas;
2709 return nullptr;
2710 }
2711 return blas;
2712 });
2713
2714 if (!status.ok()) {
2715 LOG(ERROR) << "Unable to register rocBLAS factory: "
2716 << status.error_message();
2717 }
2718
2719 PluginRegistry::Instance()->SetDefaultFactory(
2720 rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2721 }
2722 }
2723
2724 } // namespace stream_executor
2725
// Declares a module initializer named `register_rocblas` whose body calls
// stream_executor::initialize_rocblas(), i.e. the rocBLAS factory
// registration above. NOTE(review): the macro is defined outside this file;
// presumably the body runs once during static initialization -- confirm
// against the macro definition.
REGISTER_MODULE_INITIALIZER(register_rocblas,
                            { stream_executor::initialize_rocblas(); });
2728