• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "rocm/include/rocblas.h"
17 
18 #include "tensorflow/stream_executor/rocm/rocm_blas.h"
19 
20 #define EIGEN_USE_GPU
21 #include <assert.h>
22 
23 #include <complex>
24 
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/stream_executor/device_memory.h"
29 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
30 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
31 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
32 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
33 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
34 #include "tensorflow/stream_executor/lib/env.h"
35 #include "tensorflow/stream_executor/lib/initialize.h"
36 #include "tensorflow/stream_executor/lib/status.h"
37 #include "tensorflow/stream_executor/lib/status_macros.h"
38 #include "tensorflow/stream_executor/platform/dso_loader.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream_executor.h"
45 
46 namespace stream_executor {
47 namespace gpu {
48 
// Defines the plugin id `kRocBlasPlugin` through the plugin-registry macro.
// NOTE(review): presumably used when registering the rocBLAS plugin later in
// this file -- confirm against the registration code.
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
50 
51 namespace wrap {
52 
#ifdef PLATFORM_GOOGLE
// Direct-call variant. For a rocBLAS routine `__name` this defines a functor
// type WrapperShim__<name> and a global object of that type that is also
// called `__name`. The call operator activates the executor context of
// `parent` (ScopedActivateExecutorContext) for the duration of the call and
// forwards the remaining arguments to the global ::__name. `kName` holds the
// routine name as a string.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                       \
  struct WrapperShim__##__name {                                   \
    static const char *kName;                                      \
    template <typename... Args>                                    \
    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent};              \
      return ::__name(args...);                                    \
    }                                                              \
  } __name;                                                        \
  const char *WrapperShim__##__name::kName = #__name;

// Alias: the "V2" spelling expands to exactly the same wrapper.
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)

#else

// Dynamic-loading variant. For a rocBLAS routine `__name` this defines a
// functor type DynLoadShim__<name> and a global object of that type that is
// also called `__name`.
//   - GetDsoHandle(): returns the handle from
//     CachedDsoLoader::GetRocblasDsoHandle(), dying (ValueOrDie) on failure.
//   - LoadOrDie(): looks `kName` up in that handle via
//     Env::GetSymbolFromLibrary and CHECK-fails if the lookup is not ok.
//   - DynLoad(): caches the resolved function pointer in a function-local
//     static, so the lookup runs once per wrapped routine.
//   - operator(): activates the executor context of `parent` and calls the
//     resolved function with the remaining arguments.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                              \
  struct DynLoadShim__##__name {                                          \
    static const char *kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void *GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetRocblasDsoHandle();          \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void *f;                                                            \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in rocblas DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    rocblas_status operator()(GpuExecutor *parent, Args... args) {        \
      gpu::ScopedActivateExecutorContext sac{parent};                     \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char *DynLoadShim__##__name::kName = #__name;

// Alias: the "V2" spelling expands to exactly the same wrapper.
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)

#endif
102 
103 // clang-format off
// List of rocBLAS routines to wrap: ROCBLAS_BLAS_ROUTINE_EACH(m) expands to
// m(routine) for every routine named below. Entries enclosed in /* ... */ are
// commented out and therefore produce no expansion.
#define ROCBLAS_BLAS_ROUTINE_EACH(__macro)  \
  __macro(rocblas_snrm2)                    \
  __macro(rocblas_dnrm2)                    \
  __macro(rocblas_scnrm2)		    \
  __macro(rocblas_dznrm2)                   \
  __macro(rocblas_sdot)                     \
  __macro(rocblas_ddot)                     \
  __macro(rocblas_cdotu)                    \
  __macro(rocblas_cdotc)		    \
  __macro(rocblas_zdotu)		    \
  __macro(rocblas_zdotc)		    \
  __macro(rocblas_sscal)                    \
  __macro(rocblas_dscal)                    \
  __macro(rocblas_cscal)                    \
  __macro(rocblas_csscal)		    \
  __macro(rocblas_zscal)		    \
  __macro(rocblas_zdscal)		    \
  __macro(rocblas_saxpy)                    \
  __macro(rocblas_daxpy)                    \
  __macro(rocblas_caxpy)                    \
  __macro(rocblas_zaxpy)		    \
  __macro(rocblas_scopy)                    \
  __macro(rocblas_dcopy)                    \
  __macro(rocblas_ccopy)                    \
  __macro(rocblas_zcopy)		    \
  __macro(rocblas_sswap)                    \
  __macro(rocblas_dswap)                    \
  __macro(rocblas_cswap)                    \
  __macro(rocblas_zswap)		    \
  __macro(rocblas_isamax)                   \
  __macro(rocblas_idamax)                   \
  __macro(rocblas_icamax)                   \
  __macro(rocblas_izamax)		    \
  __macro(rocblas_isamin)                   \
  __macro(rocblas_idamin)                   \
  __macro(rocblas_icamin)                   \
  __macro(rocblas_izamin)		    \
  __macro(rocblas_sasum)                    \
  __macro(rocblas_dasum)                    \
  __macro(rocblas_scasum)                   \
  __macro(rocblas_dzasum)		    \
  __macro(rocblas_srot)			    \
  __macro(rocblas_drot)			    \
  __macro(rocblas_crot)			    \
  __macro(rocblas_csrot)		    \
  __macro(rocblas_zrot)			    \
  __macro(rocblas_zdrot)		    \
  __macro(rocblas_srotg)		    \
  __macro(rocblas_drotg)		    \
  __macro(rocblas_crotg)		    \
  __macro(rocblas_zrotg)		    \
  __macro(rocblas_srotm)		    \
  __macro(rocblas_drotm)		    \
  __macro(rocblas_srotmg)		    \
  __macro(rocblas_drotmg)		    \
  __macro(rocblas_sgemv)                    \
  __macro(rocblas_dgemv)                    \
  __macro(rocblas_cgemv)                    \
  __macro(rocblas_zgemv)		    \
  __macro(rocblas_sgbmv)		    \
  __macro(rocblas_dgbmv)		    \
  __macro(rocblas_cgbmv)		    \
  __macro(rocblas_zgbmv)		    \
  __macro(rocblas_strmv)		    \
  __macro(rocblas_dtrmv)		    \
  __macro(rocblas_ctrmv)		    \
  __macro(rocblas_ztrmv)		    \
  __macro(rocblas_stbmv)		    \
  __macro(rocblas_dtbmv)		    \
  __macro(rocblas_ctbmv)		    \
  __macro(rocblas_ztbmv)		    \
  __macro(rocblas_stpmv)		    \
  __macro(rocblas_dtpmv)		    \
  __macro(rocblas_ctpmv)		    \
  __macro(rocblas_ztpmv)		    \
  __macro(rocblas_strsv)		    \
  __macro(rocblas_dtrsv)		    \
  __macro(rocblas_ctrsv)		    \
  __macro(rocblas_ztrsv)		    \
  __macro(rocblas_stpsv)		    \
  __macro(rocblas_dtpsv)		    \
  __macro(rocblas_ctpsv)		    \
  __macro(rocblas_ztpsv)		    \
  __macro(rocblas_stbsv)		    \
  __macro(rocblas_dtbsv)		    \
  __macro(rocblas_ctbsv)		    \
  __macro(rocblas_ztbsv)		    \
  __macro(rocblas_ssymv)		    \
  __macro(rocblas_dsymv)		    \
  /*    __macro(rocblas_csymv)		    \
    __macro(rocblas_zsymv)              */  \
  __macro(rocblas_chemv)		    \
  __macro(rocblas_zhemv)		    \
  __macro(rocblas_ssbmv)		    \
  __macro(rocblas_dsbmv)		    \
  __macro(rocblas_chbmv)		    \
  __macro(rocblas_zhbmv)		    \
  __macro(rocblas_sspmv)		    \
  __macro(rocblas_dspmv)		    \
  __macro(rocblas_chpmv)		    \
  __macro(rocblas_zhpmv)		    \
  __macro(rocblas_sger)                     \
  __macro(rocblas_dger)                     \
  __macro(rocblas_cgeru)		    \
  __macro(rocblas_cgerc)		    \
  __macro(rocblas_zgeru)		    \
  __macro(rocblas_zgerc)		    \
  __macro(rocblas_ssyr)                     \
  __macro(rocblas_dsyr)                     \
  /*__macro(rocblas_csyr)                   \
    __macro(rocblas_zsyr)               */  \
  __macro(rocblas_cher)			    \
  __macro(rocblas_zher)			    \
  __macro(rocblas_sspr)			    \
  __macro(rocblas_dspr)			    \
  __macro(rocblas_chpr)			    \
  __macro(rocblas_zhpr)			    \
  __macro(rocblas_ssyr2)		    \
  __macro(rocblas_dsyr2)		    \
  /*  __macro(rocblas_csyr2)		    \
    __macro(rocblas_zsyr2)              */  \
  __macro(rocblas_cher2)		    \
  __macro(rocblas_zher2)		    \
  __macro(rocblas_sspr2)		    \
  __macro(rocblas_dspr2)		    \
  __macro(rocblas_chpr2)                    \
  __macro(rocblas_zhpr2)		    \
  __macro(rocblas_sgemm)                    \
  __macro(rocblas_dgemm)                    \
  __macro(rocblas_hgemm)                    \
  __macro(rocblas_cgemm)                    \
  __macro(rocblas_zgemm)		    \
  __macro(rocblas_ssyrk)		    \
  __macro(rocblas_dsyrk)		    \
  __macro(rocblas_csyrk)		    \
  __macro(rocblas_zsyrk)		    \
  __macro(rocblas_cherk)		    \
  __macro(rocblas_zherk)		    \
  __macro(rocblas_ssyr2k)		    \
  __macro(rocblas_dsyr2k)		    \
  __macro(rocblas_csyr2k)		    \
  __macro(rocblas_zsyr2k)		    \
  __macro(rocblas_cher2k)		    \
  __macro(rocblas_zher2k)		    \
  /*    __macro(rocblas_ssyrkx)		    \
    __macro(rocblas_dsyrkx)                 \
    __macro(rocblas_csyrkx)                 \
    __macro(rocblas_zsyrkx)                 \
    __macro(rocblas_cherkx)                 \
    __macro(rocblas_zherkx)             */  \
  __macro(rocblas_ssymm)		    \
  __macro(rocblas_dsymm)		    \
  __macro(rocblas_csymm)		    \
  __macro(rocblas_zsymm)		    \
  __macro(rocblas_chemm)		    \
  __macro(rocblas_zhemm)		    \
  __macro(rocblas_strsm)                    \
  __macro(rocblas_dtrsm)                    \
  __macro(rocblas_ctrsm)                    \
  __macro(rocblas_ztrsm)		    \
  __macro(rocblas_strmm)		    \
  __macro(rocblas_dtrmm)		    \
  __macro(rocblas_ctrmm)		    \
  __macro(rocblas_ztrmm)		    \
  __macro(rocblas_sgeam)                    \
  __macro(rocblas_dgeam)                    \
  /*__macro(rocblas_cgeam)                  \
    __macro(rocblas_zgeam)                  \
    __macro(rocblas_sdgmm)                  \
    __macro(rocblas_ddgmm)                  \
    __macro(rocblas_cdgmm)                  \
    __macro(rocblas_zdgmm) */
276 // clang-format on
277 
// Instantiate the wrapper objects. Handle/stream management and the strided
// batched GEMM routines are wrapped individually; every routine listed in
// ROCBLAS_BLAS_ROUTINE_EACH is wrapped by the final line. The commented-out
// lines are routines that are currently not wrapped.
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)
293 
294 }  // namespace wrap
295 
296 template <class T>
complex_cast(const DeviceMemory<T> & a)297 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
298     const DeviceMemory<T> &a) {
299   return reinterpret_cast<
300       const typename RocBlasTypeConversionHelper<T>::mapped_type *>(
301       GpuMemory(a));
302 }
303 template <class T>
complex_cast(const T & a)304 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
305     const T &a) {
306   return reinterpret_cast<
307       const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a);
308 }
309 template <class T>
complex_cast(DeviceMemory<T> * a)310 typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
311     DeviceMemory<T> *a) {
312   return reinterpret_cast<
313       typename RocBlasTypeConversionHelper<T>::mapped_type *>(
314       GpuMemoryMutable(a));
315 }
316 
// Logging hook that currently does nothing; its argument is ignored.
static void blas_log(const char * /*c*/) {
}
318 
ToString(rocblas_status status)319 static string ToString(rocblas_status status) {
320   switch (status) {
321     case rocblas_status_success:
322       return "rocblas_status_success";
323     case rocblas_status_invalid_handle:
324       return "rocblas_status_invalid_handle";
325     case rocblas_status_not_implemented:
326       return "rocblas_status_not_implemented";
327     case rocblas_status_invalid_pointer:
328       return "rocblas_status_invalid_pointer";
329     case rocblas_status_invalid_size:
330       return "rocblas_status_invalid_size";
331     case rocblas_status_memory_error:
332       return "rocblas_status_memory_error";
333     case rocblas_status_internal_error:
334       return "rocblas_status_internal_error";
335     default:
336       return absl::StrCat("<invalid rocBLAS status: ", status, ">");
337   }
338 }
339 
Init()340 bool ROCMBlas::Init() {
341   rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_);
342   if (ret != rocblas_status_success) {
343     LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
344     return false;
345   }
346 
347   return true;
348 }
349 
ROCMBlas(gpu::GpuExecutor * parent)350 ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
351     : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
352 
~ROCMBlas()353 ROCMBlas::~ROCMBlas() {
354   if (blas_ != nullptr) {
355     wrap::rocblas_destroy_handle(parent_, blas_);
356   }
357 }
358 
SetStream(Stream * stream)359 bool ROCMBlas::SetStream(Stream *stream) {
360   CHECK(stream != nullptr);
361   CHECK(AsGpuStreamValue(stream) != nullptr);
362   CHECK(blas_ != nullptr);
363   rocblas_status ret =
364       wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream));
365   if (ret != rocblas_status_success) {
366     LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
367     return false;
368   }
369 
370   return true;
371 }
372 
373 namespace {
374 
375 // Helper functions transforming blas arguments into rocBLAS arguments.
376 
ROCMBlasTranspose(blas::Transpose trans)377 rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
378   switch (trans) {
379     case blas::Transpose::kNoTranspose:
380       return rocblas_operation_none;
381     case blas::Transpose::kTranspose:
382       return rocblas_operation_transpose;
383     case blas::Transpose::kConjugateTranspose:
384       return rocblas_operation_conjugate_transpose;
385     default:
386       LOG(FATAL) << "Invalid value of blas::Transpose.";
387   }
388 }
389 
ROCMBlasUpperLower(blas::UpperLower uplo)390 rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
391   switch (uplo) {
392     case blas::UpperLower::kUpper:
393       return rocblas_fill_upper;
394     case blas::UpperLower::kLower:
395       return rocblas_fill_lower;
396     default:
397       LOG(FATAL) << "Invalid value of blas::UpperLower.";
398   }
399 }
400 
ROCMBlasDiagonal(blas::Diagonal diag)401 rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
402   switch (diag) {
403     case blas::Diagonal::kUnit:
404       return rocblas_diagonal_unit;
405     case blas::Diagonal::kNonUnit:
406       return rocblas_diagonal_non_unit;
407     default:
408       LOG(FATAL) << "Invalid value of blas::Diagonal.";
409   }
410 }
411 
ROCMBlasSide(blas::Side side)412 rocblas_side ROCMBlasSide(blas::Side side) {
413   switch (side) {
414     case blas::Side::kLeft:
415       return rocblas_side_left;
416     case blas::Side::kRight:
417       return rocblas_side_right;
418     default:
419       LOG(FATAL) << "Invalid value of blas::Side.";
420   }
421 }
422 
423 }  // namespace
424 
425 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,Args...args)426 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
427                                   bool pointer_mode_host, bool err_on_failure,
428                                   Args... args) {
429   absl::MutexLock lock{&mu_};
430 
431   CHECK(blas_ != nullptr);
432   if (!SetStream(stream)) {
433     return false;
434   }
435 
436   rocblas_status ret = rocblas_func(parent_, blas_, args...);
437   if (err_on_failure && ret != rocblas_status_success) {
438     LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
439                << ToString(ret);
440   }
441   return ret == rocblas_status_success;
442 }
443 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)444 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
445                           const DeviceMemory<float> &x, int incx,
446                           DeviceMemory<float> *result) {
447   return DoBlasInternal(wrap::rocblas_sasum, stream,
448                         /* pointer_mode_host = */ false, elem_count,
449                         GpuMemory(x), incx, GpuMemoryMutable(result));
450 }
451 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)452 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
453                           const DeviceMemory<double> &x, int incx,
454                           DeviceMemory<double> *result) {
455   return DoBlasInternal(wrap::rocblas_dasum, stream,
456                         /* pointer_mode_host = */ false, elem_count,
457                         GpuMemory(x), incx, GpuMemoryMutable(result));
458 }
459 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)460 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
461                           const DeviceMemory<std::complex<float>> &x, int incx,
462                           DeviceMemory<float> *result) {
463   return DoBlasInternal(wrap::rocblas_scasum, stream,
464                         /* pointer_mode_host = */ false, elem_count,
465                         complex_cast(x), incx, GpuMemoryMutable(result));
466 }
467 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)468 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
469                           const DeviceMemory<std::complex<double>> &x, int incx,
470                           DeviceMemory<double> *result) {
471   return DoBlasInternal(wrap::rocblas_dzasum, stream,
472                         /* pointer_mode_host = */ false, elem_count,
473                         complex_cast(x), incx, GpuMemoryMutable(result));
474 }
475 
DoBlasAxpy(Stream * stream,uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)476 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
477                           const DeviceMemory<float> &x, int incx,
478                           DeviceMemory<float> *y, int incy) {
479   blas_log("DoBlasAxpy");
480   return DoBlasInternal(wrap::rocblas_saxpy, stream,
481                         /* pointer_mode_host = */ true, elem_count, &alpha,
482                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
483 }
484 
DoBlasAxpy(Stream * stream,uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)485 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
486                           const DeviceMemory<double> &x, int incx,
487                           DeviceMemory<double> *y, int incy) {
488   blas_log("DoBlasAxpy");
489   return DoBlasInternal(wrap::rocblas_daxpy, stream,
490                         /* pointer_mode_host = */ true, elem_count, &alpha,
491                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
492 }
493 
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)494 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
495                           std::complex<float> alpha,
496                           const DeviceMemory<std::complex<float>> &x, int incx,
497                           DeviceMemory<std::complex<float>> *y, int incy) {
498   return DoBlasInternal(
499       wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count,
500       complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
501 }
502 
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)503 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
504                           std::complex<double> alpha,
505                           const DeviceMemory<std::complex<double>> &x, int incx,
506                           DeviceMemory<std::complex<double>> *y, int incy) {
507   return DoBlasInternal(
508       wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count,
509       complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
510 }
511 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)512 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
513                           const DeviceMemory<float> &x, int incx,
514                           DeviceMemory<float> *y, int incy) {
515   return DoBlasInternal(wrap::rocblas_scopy, stream,
516                         /* pointer_mode_host = */ true, elem_count,
517                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
518 }
519 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)520 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
521                           const DeviceMemory<double> &x, int incx,
522                           DeviceMemory<double> *y, int incy) {
523   return DoBlasInternal(wrap::rocblas_dcopy, stream,
524                         /* pointer_mode_host = */ true, elem_count,
525                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
526 }
527 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)528 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
529                           const DeviceMemory<std::complex<float>> &x, int incx,
530                           DeviceMemory<std::complex<float>> *y, int incy) {
531   return DoBlasInternal(wrap::rocblas_ccopy, stream,
532                         /* pointer_mode_host = */ true, elem_count,
533                         complex_cast(x), incx, complex_cast(y), incy);
534 }
535 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)536 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
537                           const DeviceMemory<std::complex<double>> &x, int incx,
538                           DeviceMemory<std::complex<double>> *y, int incy) {
539   return DoBlasInternal(wrap::rocblas_zcopy, stream,
540                         /* pointer_mode_host = */ true, elem_count,
541                         complex_cast(x), incx, complex_cast(y), incy);
542 }
543 
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)544 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
545                          const DeviceMemory<float> &x, int incx,
546                          const DeviceMemory<float> &y, int incy,
547                          DeviceMemory<float> *result) {
548   blas_log("DoBlasDot");
549   return DoBlasInternal(
550       wrap::rocblas_sdot, stream, /* pointer_mode_host = */ false, elem_count,
551       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
552 }
553 
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)554 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
555                          const DeviceMemory<double> &x, int incx,
556                          const DeviceMemory<double> &y, int incy,
557                          DeviceMemory<double> *result) {
558   blas_log("DoBlasDot");
559   return DoBlasInternal(
560       wrap::rocblas_ddot, stream, /* pointer_mode_host = */ false, elem_count,
561       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
562 }
563 
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)564 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
565                           const DeviceMemory<std::complex<float>> &x, int incx,
566                           const DeviceMemory<std::complex<float>> &y, int incy,
567                           DeviceMemory<std::complex<float>> *result) {
568   return DoBlasInternal(
569       wrap::rocblas_cdotc, stream, /* pointer_mode_host = */ false, elem_count,
570       complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
571 }
572 
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)573 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
574                           const DeviceMemory<std::complex<double>> &x, int incx,
575                           const DeviceMemory<std::complex<double>> &y, int incy,
576                           DeviceMemory<std::complex<double>> *result) {
577   return DoBlasInternal(
578       wrap::rocblas_zdotc, stream, /* pointer_mode_host = */ false, elem_count,
579       complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
580 }
581 
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)582 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
583                           const DeviceMemory<std::complex<float>> &x, int incx,
584                           const DeviceMemory<std::complex<float>> &y, int incy,
585                           DeviceMemory<std::complex<float>> *result) {
586   return DoBlasInternal(
587       wrap::rocblas_cdotu, stream, /* pointer_mode_host = */ false, elem_count,
588       complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
589 }
590 
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)591 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
592                           const DeviceMemory<std::complex<double>> &x, int incx,
593                           const DeviceMemory<std::complex<double>> &y, int incy,
594                           DeviceMemory<std::complex<double>> *result) {
595   return DoBlasInternal(
596       wrap::rocblas_zdotu, stream, /* pointer_mode_host = */ false, elem_count,
597       complex_cast(x), incx, complex_cast(y), incy, complex_cast(result));
598 }
599 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)600 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
601                           const DeviceMemory<float> &x, int incx,
602                           DeviceMemory<float> *result) {
603   return DoBlasInternal(wrap::rocblas_snrm2, stream,
604                         /* pointer_mode_host = */ false, elem_count,
605                         GpuMemory(x), incx, GpuMemoryMutable(result));
606 }
607 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)608 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
609                           const DeviceMemory<double> &x, int incx,
610                           DeviceMemory<double> *result) {
611   return DoBlasInternal(wrap::rocblas_dnrm2, stream,
612                         /* pointer_mode_host = */ false, elem_count,
613                         GpuMemory(x), incx, GpuMemoryMutable(result));
614 }
615 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)616 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
617                           const DeviceMemory<std::complex<float>> &x, int incx,
618                           DeviceMemory<float> *result) {
619   return DoBlasInternal(wrap::rocblas_scnrm2, stream,
620                         /* pointer_mode_host = */ false, elem_count,
621                         complex_cast(x), incx, GpuMemoryMutable(result));
622 }
623 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)624 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
625                           const DeviceMemory<std::complex<double>> &x, int incx,
626                           DeviceMemory<double> *result) {
627   return DoBlasInternal(wrap::rocblas_dznrm2, stream,
628                         /* pointer_mode_host = */ false, elem_count,
629                         complex_cast(x), incx, GpuMemoryMutable(result));
630 }
631 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)632 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
633                          DeviceMemory<float> *x, int incx,
634                          DeviceMemory<float> *y, int incy, float c, float s) {
635   return DoBlasInternal(
636       wrap::rocblas_srot, stream, /* pointer_mode_host = */ true, elem_count,
637       GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s);
638 }
639 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)640 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
641                          DeviceMemory<double> *x, int incx,
642                          DeviceMemory<double> *y, int incy, double c,
643                          double s) {
644   return DoBlasInternal(
645       wrap::rocblas_drot, stream, /* pointer_mode_host = */ true, elem_count,
646       GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, &c, &s);
647 }
648 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)649 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
650                          DeviceMemory<std::complex<float>> *x, int incx,
651                          DeviceMemory<std::complex<float>> *y, int incy,
652                          float c, float s) {
653   return DoBlasInternal(wrap::rocblas_csrot, stream,
654                         /* pointer_mode_host = */ true, elem_count,
655                         complex_cast(x), incx, complex_cast(y), incy, &c, &s);
656 }
657 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)658 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
659                          DeviceMemory<std::complex<double>> *x, int incx,
660                          DeviceMemory<std::complex<double>> *y, int incy,
661                          double c, double s) {
662   return DoBlasInternal(wrap::rocblas_zdrot, stream,
663                         /* pointer_mode_host = */ true, elem_count,
664                         complex_cast(x), incx, complex_cast(y), incy, &c, &s);
665 }
666 
DoBlasRotg(Stream * stream,DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)667 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
668                           DeviceMemory<float> *b, DeviceMemory<float> *c,
669                           DeviceMemory<float> *s) {
670   return DoBlasInternal(wrap::rocblas_srotg, stream,
671                         /* pointer_mode_host = */ false, GpuMemoryMutable(a),
672                         GpuMemoryMutable(b), GpuMemoryMutable(c),
673                         GpuMemoryMutable(s));
674 }
675 
DoBlasRotg(Stream * stream,DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)676 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
677                           DeviceMemory<double> *b, DeviceMemory<double> *c,
678                           DeviceMemory<double> *s) {
679   return DoBlasInternal(wrap::rocblas_drotg, stream,
680                         /* pointer_mode_host = */ false, GpuMemoryMutable(a),
681                         GpuMemoryMutable(b), GpuMemoryMutable(c),
682                         GpuMemoryMutable(s));
683 }
684 
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)685 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
686                           DeviceMemory<std::complex<float>> *b,
687                           DeviceMemory<float> *c,
688                           DeviceMemory<std::complex<float>> *s) {
689   return DoBlasInternal(wrap::rocblas_crotg, stream,
690                         /* pointer_mode_host = */ false, complex_cast(a),
691                         complex_cast(b), GpuMemoryMutable(c), complex_cast(s));
692 }
693 
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)694 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
695                           DeviceMemory<std::complex<double>> *b,
696                           DeviceMemory<double> *c,
697                           DeviceMemory<std::complex<double>> *s) {
698   return DoBlasInternal(wrap::rocblas_zrotg, stream,
699                         /* pointer_mode_host = */ false, complex_cast(a),
700                         complex_cast(b), GpuMemoryMutable(c), complex_cast(s));
701 }
702 
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)703 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
704                           DeviceMemory<float> *x, int incx,
705                           DeviceMemory<float> *y, int incy,
706                           const DeviceMemory<float> &param) {
707   return DoBlasInternal(
708       wrap::rocblas_srotm, stream, /* pointer_mode_host = */ false, elem_count,
709       GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param));
710 }
711 
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)712 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
713                           DeviceMemory<double> *x, int incx,
714                           DeviceMemory<double> *y, int incy,
715                           const DeviceMemory<double> &param) {
716   return DoBlasInternal(
717       wrap::rocblas_drotm, stream, /* pointer_mode_host = */ false, elem_count,
718       GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy, GpuMemory(param));
719 }
720 
DoBlasRotmg(Stream * stream,DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)721 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
722                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
723                            const DeviceMemory<float> &y1,
724                            DeviceMemory<float> *param) {
725   return DoBlasInternal(wrap::rocblas_srotmg, stream,
726                         /* pointer_mode_host = */ false, GpuMemoryMutable(d1),
727                         GpuMemoryMutable(d2), GpuMemoryMutable(x1),
728                         GpuMemory(y1), GpuMemoryMutable(param));
729 }
730 
DoBlasRotmg(Stream * stream,DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)731 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
732                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
733                            const DeviceMemory<double> &y1,
734                            DeviceMemory<double> *param) {
735   return DoBlasInternal(wrap::rocblas_drotmg, stream,
736                         /* pointer_mode_host = */ false, GpuMemoryMutable(d1),
737                         GpuMemoryMutable(d2), GpuMemoryMutable(x1),
738                         GpuMemory(y1), GpuMemoryMutable(param));
739 }
740 
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)741 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
742                           DeviceMemory<float> *x, int incx) {
743   blas_log("DoBlasScal<float>");
744   return DoBlasInternal(wrap::rocblas_sscal, stream,
745                         /* pointer_mode_host = */ true, elem_count, &alpha,
746                         GpuMemoryMutable(x), incx);
747 }
748 
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)749 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
750                           DeviceMemory<double> *x, int incx) {
751   return DoBlasInternal(wrap::rocblas_dscal, stream,
752                         /* pointer_mode_host = */ true, elem_count, &alpha,
753                         GpuMemoryMutable(x), incx);
754 }
755 
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)756 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
757                           DeviceMemory<std::complex<float>> *x, int incx) {
758   return DoBlasInternal(wrap::rocblas_csscal, stream,
759                         /* pointer_mode_host = */ true, elem_count, &alpha,
760                         complex_cast(x), incx);
761 }
762 
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)763 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
764                           DeviceMemory<std::complex<double>> *x, int incx) {
765   return DoBlasInternal(wrap::rocblas_zdscal, stream,
766                         /* pointer_mode_host = */ true, elem_count, &alpha,
767                         complex_cast(x), incx);
768 }
769 
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)770 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
771                           std::complex<float> alpha,
772                           DeviceMemory<std::complex<float>> *x, int incx) {
773   return DoBlasInternal(wrap::rocblas_cscal, stream,
774                         /* pointer_mode_host = */ true, elem_count,
775                         complex_cast(alpha), complex_cast(x), incx);
776 }
777 
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)778 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
779                           std::complex<double> alpha,
780                           DeviceMemory<std::complex<double>> *x, int incx) {
781   return DoBlasInternal(wrap::rocblas_zscal, stream,
782                         /* pointer_mode_host = */ true, elem_count,
783                         complex_cast(alpha), complex_cast(x), incx);
784 }
785 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)786 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
787                           DeviceMemory<float> *x, int incx,
788                           DeviceMemory<float> *y, int incy) {
789   return DoBlasInternal(wrap::rocblas_sswap, stream,
790                         /* pointer_mode_host = */ true, elem_count,
791                         GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
792 }
793 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)794 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
795                           DeviceMemory<double> *x, int incx,
796                           DeviceMemory<double> *y, int incy) {
797   return DoBlasInternal(wrap::rocblas_dswap, stream,
798                         /* pointer_mode_host = */ true, elem_count,
799                         GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
800 }
801 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)802 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
803                           DeviceMemory<std::complex<float>> *x, int incx,
804                           DeviceMemory<std::complex<float>> *y, int incy) {
805   return DoBlasInternal(wrap::rocblas_cswap, stream,
806                         /* pointer_mode_host = */ true, elem_count,
807                         complex_cast(x), incx, complex_cast(y), incy);
808 }
809 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)810 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
811                           DeviceMemory<std::complex<double>> *x, int incx,
812                           DeviceMemory<std::complex<double>> *y, int incy) {
813   return DoBlasInternal(wrap::rocblas_zswap, stream,
814                         /* pointer_mode_host = */ true, elem_count,
815                         complex_cast(x), incx, complex_cast(y), incy);
816 }
817 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)818 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
819                            const DeviceMemory<float> &x, int incx,
820                            DeviceMemory<int> *result) {
821   return DoBlasInternal(wrap::rocblas_isamax, stream,
822                         /* pointer_mode_host = */ false, elem_count,
823                         GpuMemory(x), incx, GpuMemoryMutable(result));
824 }
825 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)826 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
827                            const DeviceMemory<double> &x, int incx,
828                            DeviceMemory<int> *result) {
829   return DoBlasInternal(wrap::rocblas_idamax, stream,
830                         /* pointer_mode_host = */ false, elem_count,
831                         GpuMemory(x), incx, GpuMemoryMutable(result));
832 }
833 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)834 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
835                            const DeviceMemory<std::complex<float>> &x, int incx,
836                            DeviceMemory<int> *result) {
837   return DoBlasInternal(wrap::rocblas_icamax, stream,
838                         /* pointer_mode_host = */ false, elem_count,
839                         complex_cast(x), incx, GpuMemoryMutable(result));
840 }
841 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)842 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
843                            const DeviceMemory<std::complex<double>> &x,
844                            int incx, DeviceMemory<int> *result) {
845   return DoBlasInternal(wrap::rocblas_izamax, stream,
846                         /* pointer_mode_host = */ false, elem_count,
847                         complex_cast(x), incx, GpuMemoryMutable(result));
848 }
849 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)850 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
851                            const DeviceMemory<float> &x, int incx,
852                            DeviceMemory<int> *result) {
853   return DoBlasInternal(wrap::rocblas_isamin, stream,
854                         /* pointer_mode_host = */ false, elem_count,
855                         GpuMemory(x), incx, GpuMemoryMutable(result));
856 }
857 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)858 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
859                            const DeviceMemory<double> &x, int incx,
860                            DeviceMemory<int> *result) {
861   return DoBlasInternal(wrap::rocblas_idamin, stream,
862                         /* pointer_mode_host = */ false, elem_count,
863                         GpuMemory(x), incx, GpuMemoryMutable(result));
864 }
865 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)866 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
867                            const DeviceMemory<std::complex<float>> &x, int incx,
868                            DeviceMemory<int> *result) {
869   return DoBlasInternal(wrap::rocblas_icamin, stream,
870                         /* pointer_mode_host = */ false, elem_count,
871                         complex_cast(x), incx, GpuMemoryMutable(result));
872 }
873 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)874 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
875                            const DeviceMemory<std::complex<double>> &x,
876                            int incx, DeviceMemory<int> *result) {
877   return DoBlasInternal(wrap::rocblas_izamin, stream,
878                         /* pointer_mode_host = */ false, elem_count,
879                         complex_cast(x), incx, GpuMemoryMutable(result));
880 }
881 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)882 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
883                           uint64 n, uint64 kl, uint64 ku, float alpha,
884                           const DeviceMemory<float> &a, int lda,
885                           const DeviceMemory<float> &x, int incx, float beta,
886                           DeviceMemory<float> *y, int incy) {
887   return DoBlasInternal(
888       wrap::rocblas_sgbmv, stream, /* pointer_mode_host = */ true,
889       ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda,
890       GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
891 }
892 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)893 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
894                           uint64 n, uint64 kl, uint64 ku, double alpha,
895                           const DeviceMemory<double> &a, int lda,
896                           const DeviceMemory<double> &x, int incx, double beta,
897                           DeviceMemory<double> *y, int incy) {
898   return DoBlasInternal(
899       wrap::rocblas_dgbmv, stream, /* pointer_mode_host = */ true,
900       ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, GpuMemory(a), lda,
901       GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
902 }
903 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)904 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
905                           uint64 n, uint64 kl, uint64 ku,
906                           std::complex<float> alpha,
907                           const DeviceMemory<std::complex<float>> &a, int lda,
908                           const DeviceMemory<std::complex<float>> &x, int incx,
909                           std::complex<float> beta,
910                           DeviceMemory<std::complex<float>> *y, int incy) {
911   return DoBlasInternal(
912       wrap::rocblas_cgbmv, stream, /* pointer_mode_host = */ true,
913       ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha),
914       complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta),
915       complex_cast(y), incy);
916 }
917 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)918 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
919                           uint64 n, uint64 kl, uint64 ku,
920                           std::complex<double> alpha,
921                           const DeviceMemory<std::complex<double>> &a, int lda,
922                           const DeviceMemory<std::complex<double>> &x, int incx,
923                           std::complex<double> beta,
924                           DeviceMemory<std::complex<double>> *y, int incy) {
925   return DoBlasInternal(
926       wrap::rocblas_zgbmv, stream, /* pointer_mode_host = */ true,
927       ROCMBlasTranspose(trans), m, n, kl, ku, complex_cast(alpha),
928       complex_cast(a), lda, complex_cast(x), incx, complex_cast(beta),
929       complex_cast(y), incy);
930 }
931 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)932 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
933                           uint64 n, float alpha, const DeviceMemory<float> &a,
934                           int lda, const DeviceMemory<float> &x, int incx,
935                           float beta, DeviceMemory<float> *y, int incy) {
936   blas_log("DoBlasGemv");
937   return DoBlasInternal(
938       wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true,
939       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
940       incx, &beta, GpuMemoryMutable(y), incy);
941 }
942 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)943 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
944                           uint64 n, double alpha, const DeviceMemory<double> &a,
945                           int lda, const DeviceMemory<double> &x, int incx,
946                           double beta, DeviceMemory<double> *y, int incy) {
947   blas_log("DoBlasGemv");
948   return DoBlasInternal(
949       wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true,
950       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
951       incx, &beta, GpuMemoryMutable(y), incy);
952 }
953 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)954 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
955                           uint64 n, std::complex<float> alpha,
956                           const DeviceMemory<std::complex<float>> &a, int lda,
957                           const DeviceMemory<std::complex<float>> &x, int incx,
958                           std::complex<float> beta,
959                           DeviceMemory<std::complex<float>> *y, int incy) {
960   blas_log("DoBlasGemv");
961   return DoBlasInternal(
962       wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true,
963       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
964       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
965 }
966 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)967 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
968                           uint64 n, std::complex<double> alpha,
969                           const DeviceMemory<std::complex<double>> &a, int lda,
970                           const DeviceMemory<std::complex<double>> &x, int incx,
971                           std::complex<double> beta,
972                           DeviceMemory<std::complex<double>> *y, int incy) {
973   blas_log("DoBlasGemv\n");
974   return DoBlasInternal(
975       wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true,
976       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
977       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
978 }
979 
DoBlasGer(Stream * stream,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)980 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
981                          const DeviceMemory<float> &x, int incx,
982                          const DeviceMemory<float> &y, int incy,
983                          DeviceMemory<float> *a, int lda) {
984   return DoBlasInternal(
985       wrap::rocblas_sger, stream, /* pointer_mode_host = */ true, m, n, &alpha,
986       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
987 }
988 
DoBlasGer(Stream * stream,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)989 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
990                          const DeviceMemory<double> &x, int incx,
991                          const DeviceMemory<double> &y, int incy,
992                          DeviceMemory<double> *a, int lda) {
993   return DoBlasInternal(
994       wrap::rocblas_dger, stream, /* pointer_mode_host = */ true, m, n, &alpha,
995       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
996 }
997 
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)998 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
999                           std::complex<float> alpha,
1000                           const DeviceMemory<std::complex<float>> &x, int incx,
1001                           const DeviceMemory<std::complex<float>> &y, int incy,
1002                           DeviceMemory<std::complex<float>> *a, int lda) {
1003   return DoBlasInternal(wrap::rocblas_cgerc, stream,
1004                         /* pointer_mode_host = */ true, m, n,
1005                         complex_cast(alpha), complex_cast(x), incx,
1006                         complex_cast(y), incy, complex_cast(a), lda);
1007 }
1008 
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1009 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
1010                           std::complex<double> alpha,
1011                           const DeviceMemory<std::complex<double>> &x, int incx,
1012                           const DeviceMemory<std::complex<double>> &y, int incy,
1013                           DeviceMemory<std::complex<double>> *a, int lda) {
1014   return DoBlasInternal(wrap::rocblas_zgerc, stream,
1015                         /* pointer_mode_host = */ true, m, n,
1016                         complex_cast(alpha), complex_cast(x), incx,
1017                         complex_cast(y), incy, complex_cast(a), lda);
1018 }
1019 
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)1020 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1021                           std::complex<float> alpha,
1022                           const DeviceMemory<std::complex<float>> &x, int incx,
1023                           const DeviceMemory<std::complex<float>> &y, int incy,
1024                           DeviceMemory<std::complex<float>> *a, int lda) {
1025   return DoBlasInternal(wrap::rocblas_cgeru, stream,
1026                         /* pointer_mode_host = */ true, m, n,
1027                         complex_cast(alpha), complex_cast(x), incx,
1028                         complex_cast(y), incy, complex_cast(a), lda);
1029 }
1030 
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1031 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
1032                           std::complex<double> alpha,
1033                           const DeviceMemory<std::complex<double>> &x, int incx,
1034                           const DeviceMemory<std::complex<double>> &y, int incy,
1035                           DeviceMemory<std::complex<double>> *a, int lda) {
1036   return DoBlasInternal(wrap::rocblas_zgeru, stream,
1037                         /* pointer_mode_host = */ true, m, n,
1038                         complex_cast(alpha), complex_cast(x), incx,
1039                         complex_cast(y), incy, complex_cast(a), lda);
1040 }
1041 
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1042 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1043                           uint64 k, std::complex<float> alpha,
1044                           const DeviceMemory<std::complex<float>> &a, int lda,
1045                           const DeviceMemory<std::complex<float>> &x, int incx,
1046                           std::complex<float> beta,
1047                           DeviceMemory<std::complex<float>> *y, int incy) {
1048   return DoBlasInternal(
1049       wrap::rocblas_chbmv, stream, /* pointer_mode_host = */ true,
1050       ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda,
1051       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1052 }
1053 
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1054 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1055                           uint64 k, std::complex<double> alpha,
1056                           const DeviceMemory<std::complex<double>> &a, int lda,
1057                           const DeviceMemory<std::complex<double>> &x, int incx,
1058                           std::complex<double> beta,
1059                           DeviceMemory<std::complex<double>> *y, int incy) {
1060   return DoBlasInternal(
1061       wrap::rocblas_zhbmv, stream, /* pointer_mode_host = */ true,
1062       ROCMBlasUpperLower(uplo), n, k, complex_cast(alpha), complex_cast(a), lda,
1063       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1064 }
1065 
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1066 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1067                           std::complex<float> alpha,
1068                           const DeviceMemory<std::complex<float>> &a, int lda,
1069                           const DeviceMemory<std::complex<float>> &x, int incx,
1070                           std::complex<float> beta,
1071                           DeviceMemory<std::complex<float>> *y, int incy) {
1072   return DoBlasInternal(
1073       wrap::rocblas_chemv, stream, /* pointer_mode_host = */ true,
1074       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda,
1075       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1076 }
1077 
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1078 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1079                           std::complex<double> alpha,
1080                           const DeviceMemory<std::complex<double>> &a, int lda,
1081                           const DeviceMemory<std::complex<double>> &x, int incx,
1082                           std::complex<double> beta,
1083                           DeviceMemory<std::complex<double>> *y, int incy) {
1084   return DoBlasInternal(
1085       wrap::rocblas_zhemv, stream, /* pointer_mode_host = */ true,
1086       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(a), lda,
1087       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1088 }
1089 
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)1090 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1091                          float alpha,
1092                          const DeviceMemory<std::complex<float>> &x, int incx,
1093                          DeviceMemory<std::complex<float>> *a, int lda) {
1094   return DoBlasInternal(wrap::rocblas_cher, stream,
1095                         /* pointer_mode_host = */ true,
1096                         ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1097                         complex_cast(x), incx, complex_cast(a), lda);
1098 }
1099 
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)1100 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1101                          double alpha,
1102                          const DeviceMemory<std::complex<double>> &x, int incx,
1103                          DeviceMemory<std::complex<double>> *a, int lda) {
1104   return DoBlasInternal(wrap::rocblas_zher, stream,
1105                         /* pointer_mode_host = */ true,
1106                         ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1107                         complex_cast(x), incx, complex_cast(a), lda);
1108 }
1109 
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)1110 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1111                           std::complex<float> alpha,
1112                           const DeviceMemory<std::complex<float>> &x, int incx,
1113                           const DeviceMemory<std::complex<float>> &y, int incy,
1114                           DeviceMemory<std::complex<float>> *a, int lda) {
1115   return DoBlasInternal(
1116       wrap::rocblas_cher2, stream, /* pointer_mode_host = */ true,
1117       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1118       complex_cast(y), incy, complex_cast(a), lda);
1119 }
1120 
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1121 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1122                           std::complex<double> alpha,
1123                           const DeviceMemory<std::complex<double>> &x, int incx,
1124                           const DeviceMemory<std::complex<double>> &y, int incy,
1125                           DeviceMemory<std::complex<double>> *a, int lda) {
1126   return DoBlasInternal(
1127       wrap::rocblas_zher2, stream, /* pointer_mode_host = */ true,
1128       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1129       complex_cast(y), incy, complex_cast(a), lda);
1130 }
1131 
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1132 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1133                           std::complex<float> alpha,
1134                           const DeviceMemory<std::complex<float>> &ap,
1135                           const DeviceMemory<std::complex<float>> &x, int incx,
1136                           std::complex<float> beta,
1137                           DeviceMemory<std::complex<float>> *y, int incy) {
1138   return DoBlasInternal(
1139       wrap::rocblas_chpmv, stream, /* pointer_mode_host = */ true,
1140       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap),
1141       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1142 }
1143 
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1144 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1145                           std::complex<double> alpha,
1146                           const DeviceMemory<std::complex<double>> &ap,
1147                           const DeviceMemory<std::complex<double>> &x, int incx,
1148                           std::complex<double> beta,
1149                           DeviceMemory<std::complex<double>> *y, int incy) {
1150   return DoBlasInternal(
1151       wrap::rocblas_zhpmv, stream, /* pointer_mode_host = */ true,
1152       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(ap),
1153       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
1154 }
1155 
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)1156 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1157                          float alpha,
1158                          const DeviceMemory<std::complex<float>> &x, int incx,
1159                          DeviceMemory<std::complex<float>> *ap) {
1160   return DoBlasInternal(wrap::rocblas_chpr, stream,
1161                         /* pointer_mode_host = */ true,
1162                         ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1163                         complex_cast(x), incx, complex_cast(ap));
1164 }
1165 
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)1166 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1167                          double alpha,
1168                          const DeviceMemory<std::complex<double>> &x, int incx,
1169                          DeviceMemory<std::complex<double>> *ap) {
1170   return DoBlasInternal(wrap::rocblas_zhpr, stream,
1171                         /* pointer_mode_host = */ true,
1172                         ROCMBlasUpperLower(uplo), n, complex_cast(alpha),
1173                         complex_cast(x), incx, complex_cast(ap));
1174 }
1175 
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)1176 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1177                           std::complex<float> alpha,
1178                           const DeviceMemory<std::complex<float>> &x, int incx,
1179                           const DeviceMemory<std::complex<float>> &y, int incy,
1180                           DeviceMemory<std::complex<float>> *ap) {
1181   return DoBlasInternal(
1182       wrap::rocblas_chpr2, stream, /* pointer_mode_host = */ true,
1183       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1184       complex_cast(y), incy, complex_cast(ap));
1185 }
1186 
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)1187 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1188                           std::complex<double> alpha,
1189                           const DeviceMemory<std::complex<double>> &x, int incx,
1190                           const DeviceMemory<std::complex<double>> &y, int incy,
1191                           DeviceMemory<std::complex<double>> *ap) {
1192   return DoBlasInternal(
1193       wrap::rocblas_zhpr2, stream, /* pointer_mode_host = */ true,
1194       ROCMBlasUpperLower(uplo), n, complex_cast(alpha), complex_cast(x), incx,
1195       complex_cast(y), incy, complex_cast(ap));
1196 }
1197 
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1198 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1199                           uint64 k, float alpha, const DeviceMemory<float> &a,
1200                           int lda, const DeviceMemory<float> &x, int incx,
1201                           float beta, DeviceMemory<float> *y, int incy) {
1202   return DoBlasInternal(
1203       wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true,
1204       ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
1205       incx, &beta, GpuMemoryMutable(y), incy);
1206 }
1207 
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1208 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1209                           uint64 k, double alpha, const DeviceMemory<double> &a,
1210                           int lda, const DeviceMemory<double> &x, int incx,
1211                           double beta, DeviceMemory<double> *y, int incy) {
1212   return DoBlasInternal(
1213       wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true,
1214       ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
1215       incx, &beta, GpuMemoryMutable(y), incy);
1216 }
1217 
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1218 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1219                           float alpha, const DeviceMemory<float> &ap,
1220                           const DeviceMemory<float> &x, int incx, float beta,
1221                           DeviceMemory<float> *y, int incy) {
1222   return DoBlasInternal(wrap::rocblas_sspmv, stream,
1223                         /* pointer_mode_host = */ true,
1224                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1225                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1226 }
1227 
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1228 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1229                           double alpha, const DeviceMemory<double> &ap,
1230                           const DeviceMemory<double> &x, int incx, double beta,
1231                           DeviceMemory<double> *y, int incy) {
1232   return DoBlasInternal(wrap::rocblas_dspmv, stream,
1233                         /* pointer_mode_host = */ true,
1234                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(ap),
1235                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1236 }
1237 
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)1238 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1239                          float alpha, const DeviceMemory<float> &x, int incx,
1240                          DeviceMemory<float> *ap) {
1241   return DoBlasInternal(wrap::rocblas_sspr, stream,
1242                         /* pointer_mode_host = */ true,
1243                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1244                         GpuMemoryMutable(ap));
1245 }
1246 
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)1247 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1248                          double alpha, const DeviceMemory<double> &x, int incx,
1249                          DeviceMemory<double> *ap) {
1250   return DoBlasInternal(wrap::rocblas_dspr, stream,
1251                         /* pointer_mode_host = */ true,
1252                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1253                         GpuMemoryMutable(ap));
1254 }
1255 
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)1256 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1257                           float alpha, const DeviceMemory<float> &x, int incx,
1258                           const DeviceMemory<float> &y, int incy,
1259                           DeviceMemory<float> *ap) {
1260   return DoBlasInternal(wrap::rocblas_sspr2, stream,
1261                         /* pointer_mode_host = */ true,
1262                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1263                         GpuMemory(y), incy, GpuMemoryMutable(ap));
1264 }
1265 
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)1266 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1267                           double alpha, const DeviceMemory<double> &x, int incx,
1268                           const DeviceMemory<double> &y, int incy,
1269                           DeviceMemory<double> *ap) {
1270   return DoBlasInternal(wrap::rocblas_dspr2, stream,
1271                         /* pointer_mode_host = */ true,
1272                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1273                         GpuMemory(y), incy, GpuMemoryMutable(ap));
1274 }
1275 
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1276 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1277                           float alpha, const DeviceMemory<float> &a, int lda,
1278                           const DeviceMemory<float> &x, int incx, float beta,
1279                           DeviceMemory<float> *y, int incy) {
1280   return DoBlasInternal(wrap::rocblas_ssymv, stream,
1281                         /* pointer_mode_host = */ true,
1282                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1283                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1284 }
1285 
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1286 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1287                           double alpha, const DeviceMemory<double> &a, int lda,
1288                           const DeviceMemory<double> &x, int incx, double beta,
1289                           DeviceMemory<double> *y, int incy) {
1290   return DoBlasInternal(wrap::rocblas_dsymv, stream,
1291                         /* pointer_mode_host = */ true,
1292                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(a), lda,
1293                         GpuMemory(x), incx, &beta, GpuMemoryMutable(y), incy);
1294 }
1295 
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)1296 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1297                          float alpha, const DeviceMemory<float> &x, int incx,
1298                          DeviceMemory<float> *a, int lda) {
1299   return DoBlasInternal(wrap::rocblas_ssyr, stream,
1300                         /* pointer_mode_host = */ true,
1301                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1302                         GpuMemoryMutable(a), lda);
1303 }
1304 
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)1305 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1306                          double alpha, const DeviceMemory<double> &x, int incx,
1307                          DeviceMemory<double> *a, int lda) {
1308   return DoBlasInternal(wrap::rocblas_dsyr, stream,
1309                         /* pointer_mode_host = */ true,
1310                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1311                         GpuMemoryMutable(a), lda);
1312 }
1313 
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)1314 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1315                           float alpha, const DeviceMemory<float> &x, int incx,
1316                           const DeviceMemory<float> &y, int incy,
1317                           DeviceMemory<float> *a, int lda) {
1318   return DoBlasInternal(wrap::rocblas_ssyr2, stream,
1319                         /* pointer_mode_host = */ true,
1320                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1321                         GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1322 }
1323 
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)1324 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1325                           double alpha, const DeviceMemory<double> &x, int incx,
1326                           const DeviceMemory<double> &y, int incy,
1327                           DeviceMemory<double> *a, int lda) {
1328   return DoBlasInternal(wrap::rocblas_dsyr2, stream,
1329                         /* pointer_mode_host = */ true,
1330                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1331                         GpuMemory(y), incy, GpuMemoryMutable(a), lda);
1332 }
1333 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1334 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1335                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1336                           uint64 k, const DeviceMemory<float> &a, int lda,
1337                           DeviceMemory<float> *x, int incx) {
1338   return DoBlasInternal(wrap::rocblas_stbmv, stream,
1339                         /* pointer_mode_host = */ false,
1340                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1341                         ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1342                         GpuMemoryMutable(x), incx);
1343 }
1344 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1345 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1346                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1347                           uint64 k, const DeviceMemory<double> &a, int lda,
1348                           DeviceMemory<double> *x, int incx) {
1349   return DoBlasInternal(wrap::rocblas_dtbmv, stream,
1350                         /* pointer_mode_host = */ false,
1351                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1352                         ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1353                         GpuMemoryMutable(x), incx);
1354 }
1355 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1356 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1357                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1358                           uint64 k, const DeviceMemory<std::complex<float>> &a,
1359                           int lda, DeviceMemory<std::complex<float>> *x,
1360                           int incx) {
1361   return DoBlasInternal(wrap::rocblas_ctbmv, stream,
1362                         /* pointer_mode_host = */ false,
1363                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1364                         ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1365                         complex_cast(x), incx);
1366 }
1367 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1368 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1369                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1370                           uint64 k, const DeviceMemory<std::complex<double>> &a,
1371                           int lda, DeviceMemory<std::complex<double>> *x,
1372                           int incx) {
1373   return DoBlasInternal(wrap::rocblas_ztbmv, stream,
1374                         /* pointer_mode_host = */ false,
1375                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1376                         ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1377                         complex_cast(x), incx);
1378 }
1379 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1380 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1381                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1382                           uint64 k, const DeviceMemory<float> &a, int lda,
1383                           DeviceMemory<float> *x, int incx) {
1384   return DoBlasInternal(wrap::rocblas_stbsv, stream,
1385                         /* pointer_mode_host = */ false,
1386                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1387                         ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1388                         GpuMemoryMutable(x), incx);
1389 }
1390 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1391 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1392                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1393                           uint64 k, const DeviceMemory<double> &a, int lda,
1394                           DeviceMemory<double> *x, int incx) {
1395   return DoBlasInternal(wrap::rocblas_dtbsv, stream,
1396                         /* pointer_mode_host = */ false,
1397                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1398                         ROCMBlasDiagonal(diag), n, k, GpuMemory(a), lda,
1399                         GpuMemoryMutable(x), incx);
1400 }
1401 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1402 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1403                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1404                           uint64 k, const DeviceMemory<std::complex<float>> &a,
1405                           int lda, DeviceMemory<std::complex<float>> *x,
1406                           int incx) {
1407   return DoBlasInternal(wrap::rocblas_ctbsv, stream,
1408                         /* pointer_mode_host = */ false,
1409                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1410                         ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1411                         complex_cast(x), incx);
1412 }
1413 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1414 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1415                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1416                           uint64 k, const DeviceMemory<std::complex<double>> &a,
1417                           int lda, DeviceMemory<std::complex<double>> *x,
1418                           int incx) {
1419   return DoBlasInternal(wrap::rocblas_ztbsv, stream,
1420                         /* pointer_mode_host = */ false,
1421                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1422                         ROCMBlasDiagonal(diag), n, k, complex_cast(a), lda,
1423                         complex_cast(x), incx);
1424 }
1425 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1426 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1427                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1428                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1429                           int incx) {
1430   return DoBlasInternal(
1431       wrap::rocblas_stpmv, stream, /* pointer_mode_host = */ false,
1432       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1433       ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1434 }
1435 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1436 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1437                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1438                           const DeviceMemory<double> &ap,
1439                           DeviceMemory<double> *x, int incx) {
1440   return DoBlasInternal(
1441       wrap::rocblas_dtpmv, stream, /* pointer_mode_host = */ false,
1442       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1443       ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1444 }
1445 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1446 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1447                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1448                           const DeviceMemory<std::complex<float>> &ap,
1449                           DeviceMemory<std::complex<float>> *x, int incx) {
1450   return DoBlasInternal(
1451       wrap::rocblas_ctpmv, stream, /* pointer_mode_host = */ false,
1452       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1453       ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1454 }
1455 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1456 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1457                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1458                           const DeviceMemory<std::complex<double>> &ap,
1459                           DeviceMemory<std::complex<double>> *x, int incx) {
1460   return DoBlasInternal(
1461       wrap::rocblas_ztpmv, stream, /* pointer_mode_host = */ false,
1462       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1463       ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1464 }
1465 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1466 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1467                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1468                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1469                           int incx) {
1470   return DoBlasInternal(
1471       wrap::rocblas_stpsv, stream, /* pointer_mode_host = */ false,
1472       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1473       ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1474 }
1475 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1476 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1477                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1478                           const DeviceMemory<double> &ap,
1479                           DeviceMemory<double> *x, int incx) {
1480   return DoBlasInternal(
1481       wrap::rocblas_dtpsv, stream, /* pointer_mode_host = */ false,
1482       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1483       ROCMBlasDiagonal(diag), n, GpuMemory(ap), GpuMemoryMutable(x), incx);
1484 }
1485 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1486 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1487                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1488                           const DeviceMemory<std::complex<float>> &ap,
1489                           DeviceMemory<std::complex<float>> *x, int incx) {
1490   return DoBlasInternal(
1491       wrap::rocblas_ctpsv, stream, /* pointer_mode_host = */ false,
1492       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1493       ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1494 }
1495 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1496 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1497                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1498                           const DeviceMemory<std::complex<double>> &ap,
1499                           DeviceMemory<std::complex<double>> *x, int incx) {
1500   return DoBlasInternal(
1501       wrap::rocblas_ztpsv, stream, /* pointer_mode_host = */ false,
1502       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1503       ROCMBlasDiagonal(diag), n, complex_cast(ap), complex_cast(x), incx);
1504 }
1505 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1506 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1507                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1508                           const DeviceMemory<float> &a, int lda,
1509                           DeviceMemory<float> *x, int incx) {
1510   return DoBlasInternal(
1511       wrap::rocblas_strmv, stream, /* pointer_mode_host = */ false,
1512       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1513       ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1514 }
1515 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1516 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1517                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1518                           const DeviceMemory<double> &a, int lda,
1519                           DeviceMemory<double> *x, int incx) {
1520   return DoBlasInternal(
1521       wrap::rocblas_dtrmv, stream, /* pointer_mode_host = */ false,
1522       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1523       ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1524 }
1525 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1526 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1527                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1528                           const DeviceMemory<std::complex<float>> &a, int lda,
1529                           DeviceMemory<std::complex<float>> *x, int incx) {
1530   return DoBlasInternal(
1531       wrap::rocblas_ctrmv, stream, /* pointer_mode_host = */ false,
1532       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1533       ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1534 }
1535 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1536 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1537                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1538                           const DeviceMemory<std::complex<double>> &a, int lda,
1539                           DeviceMemory<std::complex<double>> *x, int incx) {
1540   return DoBlasInternal(
1541       wrap::rocblas_ztrmv, stream, /* pointer_mode_host = */ false,
1542       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1543       ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1544 }
1545 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1546 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1547                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1548                           const DeviceMemory<float> &a, int lda,
1549                           DeviceMemory<float> *x, int incx) {
1550   return DoBlasInternal(
1551       wrap::rocblas_strsv, stream, /* pointer_mode_host = */ false,
1552       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1553       ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1554 }
1555 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1556 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1557                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1558                           const DeviceMemory<double> &a, int lda,
1559                           DeviceMemory<double> *x, int incx) {
1560   return DoBlasInternal(
1561       wrap::rocblas_dtrsv, stream, /* pointer_mode_host = */ false,
1562       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1563       ROCMBlasDiagonal(diag), n, GpuMemory(a), lda, GpuMemoryMutable(x), incx);
1564 }
1565 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1566 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1567                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1568                           const DeviceMemory<std::complex<float>> &a, int lda,
1569                           DeviceMemory<std::complex<float>> *x, int incx) {
1570   return DoBlasInternal(
1571       wrap::rocblas_ctrsv, stream, /* pointer_mode_host = */ false,
1572       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1573       ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1574 }
1575 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1576 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1577                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1578                           const DeviceMemory<std::complex<double>> &a, int lda,
1579                           DeviceMemory<std::complex<double>> *x, int incx) {
1580   return DoBlasInternal(
1581       wrap::rocblas_ztrsv, stream, /* pointer_mode_host = */ false,
1582       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans),
1583       ROCMBlasDiagonal(diag), n, complex_cast(a), lda, complex_cast(x), incx);
1584 }
1585 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)1586 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1587                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1588                           float alpha, const DeviceMemory<Eigen::half> &a,
1589                           int lda, const DeviceMemory<Eigen::half> &b, int ldb,
1590                           float beta, DeviceMemory<Eigen::half> *c, int ldc) {
1591   blas_log("DoBlasGemm");
1592   VLOG(1) << absl::StreamFormat(
1593       "doing rocBLAS SGEMM<half>: at=%d bt=%d m=%u n=%u "
1594       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1595       "c=%p ldc=%d",
1596       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1597       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1598   if (transa == blas::Transpose::kNoTranspose) {
1599     if (lda < static_cast<int64>(m)) {
1600       LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1601                       "precondition violation";
1602     }
1603   } else {
1604     if (lda < static_cast<int64>(k)) {
1605       LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1606                    << ") (transpose case); precondition violation";
1607     }
1608   }
1609   if (transb == blas::Transpose::kNoTranspose) {
1610     if (ldb < static_cast<int64>(k)) {
1611       LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1612                    << ") (no transpose case); precondition violation";
1613     }
1614   } else {
1615     if (ldb < static_cast<int64>(n)) {
1616       LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1617                       "precondition violation";
1618     }
1619   }
1620   const Eigen::half alpha_half(alpha);
1621   const Eigen::half beta_half(beta);
1622   return DoBlasInternal(
1623       wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
1624       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1625       reinterpret_cast<const rocblas_half *>(&alpha_half),
1626       reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
1627       reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
1628       reinterpret_cast<const rocblas_half *>(&beta_half),
1629       reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
1630 }
1631 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)1632 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1633                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1634                           float alpha, const DeviceMemory<float> &a, int lda,
1635                           const DeviceMemory<float> &b, int ldb, float beta,
1636                           DeviceMemory<float> *c, int ldc) {
1637   blas_log("DoBlasGemm");
1638   VLOG(1) << absl::StreamFormat(
1639       "doing rocBLAS SGEMM<float>: at=%d bt=%d m=%u n=%u "
1640       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1641       "c=%p ldc=%d",
1642       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1643       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1644   if (transa == blas::Transpose::kNoTranspose) {
1645     if (lda < static_cast<int64>(m)) {
1646       LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1647                       "precondition violation";
1648     }
1649   } else {
1650     if (lda < static_cast<int64>(k)) {
1651       LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1652                    << ") (transpose case); precondition violation";
1653     }
1654   }
1655   if (transb == blas::Transpose::kNoTranspose) {
1656     if (ldb < static_cast<int64>(k)) {
1657       LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1658                    << ") (no transpose case); precondition violation";
1659     }
1660   } else {
1661     if (ldb < static_cast<int64>(n)) {
1662       LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1663                       "precondition violation";
1664     }
1665   }
1666   return DoBlasInternal(
1667       wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true,
1668       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1669       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1670 }
1671 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)1672 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1673                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1674                           double alpha, const DeviceMemory<double> &a, int lda,
1675                           const DeviceMemory<double> &b, int ldb, double beta,
1676                           DeviceMemory<double> *c, int ldc) {
1677   blas_log("DoBlasGemm");
1678   return DoBlasInternal(
1679       wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true,
1680       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1681       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1682 }
1683 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)1684 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1685                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1686                           std::complex<float> alpha,
1687                           const DeviceMemory<std::complex<float>> &a, int lda,
1688                           const DeviceMemory<std::complex<float>> &b, int ldb,
1689                           std::complex<float> beta,
1690                           DeviceMemory<std::complex<float>> *c, int ldc) {
1691   blas_log("DoBlasGemm");
1692   return DoBlasInternal(
1693       wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true,
1694       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1695       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
1696       complex_cast(beta), complex_cast(c), ldc);
1697 }
1698 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)1699 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1700                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1701                           std::complex<double> alpha,
1702                           const DeviceMemory<std::complex<double>> &a, int lda,
1703                           const DeviceMemory<std::complex<double>> &b, int ldb,
1704                           std::complex<double> beta,
1705                           DeviceMemory<std::complex<double>> *c, int ldc) {
1706   blas_log("DoBlasGemm");
1707   return DoBlasInternal(
1708       wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true,
1709       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1710       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
1711       complex_cast(beta), complex_cast(c), ldc);
1712 }
1713 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)1714 bool ROCMBlas::DoBlasGemvWithProfiling(
1715     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1716     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1717     int incx, float beta, DeviceMemory<float> *y, int incy,
1718     blas::ProfileResult *output_profile_result) {
1719   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1720                                      incx, beta, y, incy,
1721                                      output_profile_result);
1722 }
1723 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)1724 bool ROCMBlas::DoBlasGemvWithProfiling(
1725     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1726     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1727     int incx, double beta, DeviceMemory<double> *y, int incy,
1728     blas::ProfileResult *output_profile_result) {
1729   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1730                                      incx, beta, y, incy,
1731                                      output_profile_result);
1732 }
1733 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)1734 bool ROCMBlas::DoBlasGemvWithProfiling(
1735     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1736     std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1737     int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1738     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1739     blas::ProfileResult *output_profile_result) {
1740   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1741                                      incx, beta, y, incy,
1742                                      output_profile_result);
1743 }
1744 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)1745 bool ROCMBlas::DoBlasGemvWithProfiling(
1746     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1747     std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1748     int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1749     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1750     blas::ProfileResult *output_profile_result) {
1751   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1752                                      incx, beta, y, incy,
1753                                      output_profile_result);
1754 }
1755 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)1756 bool ROCMBlas::DoBlasGemmWithProfiling(
1757     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1758     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1759     int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1760     DeviceMemory<Eigen::half> *c, int ldc,
1761     blas::ProfileResult *output_profile_result) {
1762   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1763                                      lda, b, ldb, beta, c, ldc,
1764                                      output_profile_result);
1765 }
1766 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)1767 bool ROCMBlas::DoBlasGemmWithProfiling(
1768     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1769     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1770     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1771     int ldc, blas::ProfileResult *output_profile_result) {
1772   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1773                                      lda, b, ldb, beta, c, ldc,
1774                                      output_profile_result);
1775 }
1776 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)1777 bool ROCMBlas::DoBlasGemmWithProfiling(
1778     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1779     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1780     const DeviceMemory<double> &b, int ldb, double beta,
1781     DeviceMemory<double> *c, int ldc,
1782     blas::ProfileResult *output_profile_result) {
1783   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1784                                      lda, b, ldb, beta, c, ldc,
1785                                      output_profile_result);
1786 }
1787 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)1788 bool ROCMBlas::DoBlasGemmWithProfiling(
1789     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1790     uint64 n, uint64 k, std::complex<float> alpha,
1791     const DeviceMemory<std::complex<float>> &a, int lda,
1792     const DeviceMemory<std::complex<float>> &b, int ldb,
1793     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1794     blas::ProfileResult *output_profile_result) {
1795   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1796                                      lda, b, ldb, beta, c, ldc,
1797                                      output_profile_result);
1798 }
1799 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)1800 bool ROCMBlas::DoBlasGemmWithProfiling(
1801     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1802     uint64 n, uint64 k, std::complex<double> alpha,
1803     const DeviceMemory<std::complex<double>> &a, int lda,
1804     const DeviceMemory<std::complex<double>> &b, int ldb,
1805     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1806     blas::ProfileResult *output_profile_result) {
1807   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1808                                      lda, b, ldb, beta, c, ldc,
1809                                      output_profile_result);
1810 }
1811 
1812 template <typename T>
DoBlasGemvWithProfilingImpl(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,const T & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & x,int incx,const T & beta,DeviceMemory<T> * y,int incy,blas::ProfileResult * output_profile_result)1813 bool ROCMBlas::DoBlasGemvWithProfilingImpl(
1814     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1815     const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1816     const T &beta, DeviceMemory<T> *y, int incy,
1817     blas::ProfileResult *output_profile_result) {
1818   // ROCM TODO: properly implement the interface
1819   return false;
1820 }
1821 
1822 template <typename T, typename ParamType>
DoBlasGemmWithProfilingImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const ParamType & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & b,int ldb,const ParamType & beta,DeviceMemory<T> * c,int ldc,blas::ProfileResult * output_profile_result)1823 bool ROCMBlas::DoBlasGemmWithProfilingImpl(
1824     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1825     uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
1826     int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
1827     DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
1828   // ROCM TODO: properly implement the interface
1829   return false;
1830 }
1831 
1832 template <typename InT, typename OutT, typename CompT>
DoBlasGemmWithAlgorithmImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const CompT & alpha,const DeviceMemory<InT> & a,int lda,const DeviceMemory<InT> & b,int ldb,const CompT & beta,DeviceMemory<OutT> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1833 bool ROCMBlas::DoBlasGemmWithAlgorithmImpl(
1834     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1835     uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
1836     const DeviceMemory<InT> &b, int ldb, const CompT &beta,
1837     DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
1838     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1839   // ROCM TODO: properly implement the interface
1840   return false;
1841 }
1842 
GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> * out_algorithms)1843 bool ROCMBlas::GetBlasGemmAlgorithms(
1844     std::vector<blas::AlgorithmType> *out_algorithms) {
1845   // ROCM TODO: properly implement the interface
1846   return true;
1847 }
1848 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int32> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1849 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1850     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1851     uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
1852     const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
1853     const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
1854     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1855     blas::ProfileResult *output_profile_result) {
1856   LOG(ERROR)
1857       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1858       << "for the \"int8\" datatype";
1859   return false;
1860 }
1861 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1862 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1863     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1864     uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1865     const DeviceMemory<Eigen::half> &a, int lda,
1866     const DeviceMemory<Eigen::half> &b, int ldb,
1867     const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1868     int ldc, blas::ComputationType computation_type,
1869     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1870   LOG(ERROR)
1871       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1872       << "for the \"half\" datatype";
1873   return false;
1874 }
1875 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1876 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1877     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1878     uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
1879     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1880     int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1881     int ldc, blas::ComputationType computation_type,
1882     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1883   LOG(ERROR)
1884       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1885       << "for the \"float\" datatype";
1886   return false;
1887 }
1888 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1889 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1890     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1891     uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
1892     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1893     int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1894     int ldc, blas::ComputationType computation_type,
1895     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1896   LOG(ERROR)
1897       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1898       << "for the \"double\" datatype";
1899   return false;
1900 }
1901 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1902 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1903     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1904     uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1905     const DeviceMemory<std::complex<float>> &a, int lda,
1906     const DeviceMemory<std::complex<float>> &b, int ldb,
1907     const HostOrDeviceScalar<std::complex<float>> &beta,
1908     DeviceMemory<std::complex<float>> *c, int ldc,
1909     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1910     blas::ProfileResult *output_profile_result) {
1911   LOG(ERROR)
1912       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1913       << "for the \"complex<float>\" datatype";
1914   return false;
1915 }
1916 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1917 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1918     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1919     uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1920     const DeviceMemory<std::complex<double>> &a, int lda,
1921     const DeviceMemory<std::complex<double>> &b, int ldb,
1922     const HostOrDeviceScalar<std::complex<double>> &beta,
1923     DeviceMemory<std::complex<double>> *c, int ldc,
1924     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1925     blas::ProfileResult *output_profile_result) {
1926   LOG(ERROR)
1927       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1928       << "for the \"complex<double>\" datatype";
1929   return false;
1930 }
1931 
1932 // This copies from source memory: raw_ptrs[i] to target memory:
1933 // device_memory_ptr at the interval of matrix_byte_size, or vice versa.
1934 // The below algorithm tries to minimize the number of memcpy by consolidating
1935 // neighboring memcpy into a single request
1936 template <typename MAPPED_T>
ReorganizeMemory(Stream * stream,DeviceMemory<MAPPED_T> * device_memory,const std::vector<MAPPED_T * > & raw_ptrs,int batch_count,uint64_t batch_stride,bool gather)1937 port::Status ReorganizeMemory(Stream *stream,
1938                               DeviceMemory<MAPPED_T> *device_memory,
1939                               const std::vector<MAPPED_T *> &raw_ptrs,
1940                               int batch_count, uint64_t batch_stride,
1941                               bool gather) {
1942   assert(batch_count > 0);
1943   char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
1944   char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
1945   char *dst_ptr = device_memory_ptr;
1946   size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
1947   uint64_t cur_stride_size = matrix_byte_size;
1948 
1949   for (int i = 1; i < batch_count; ++i) {
1950     if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
1951       cur_stride_size += matrix_byte_size;
1952     } else {
1953       DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
1954       DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
1955       bool a_status =
1956           gather
1957               ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
1958               : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
1959       if (!a_status) {
1960         return port::Status(
1961             port::error::INTERNAL,
1962             "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
1963       }
1964       src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
1965       dst_ptr = device_memory_ptr + i * matrix_byte_size;
1966       cur_stride_size = matrix_byte_size;
1967     }
1968   }
1969 
1970   DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
1971   DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
1972   bool a_status =
1973       gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
1974              : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
1975   if (!a_status)
1976     return port::Status(
1977         port::error::INTERNAL,
1978         "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
1979   return port::Status::OK();
1980 }
1981 
1982 template <typename T>
AllocateStridedBuffer(const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type * > & raw_ptrs,int batch_count,uint64_t batch_stride,ScratchAllocator * scratch_allocator,Stream * stream,std::unique_ptr<TemporaryDeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>> * temp_memory,DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> * device_memory,bool copy_data,bool & reallocated)1983 port::Status ROCMBlas::AllocateStridedBuffer(
1984     const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
1985         &raw_ptrs,
1986     int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
1987     Stream *stream,
1988     std::unique_ptr<TemporaryDeviceMemory<
1989         typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
1990     DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
1991         *device_memory,
1992     bool copy_data, bool &reallocated) {
1993   assert(device_memory != nullptr);
1994 
1995   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
1996 
1997   bool needs_allocate_strided = false;
1998   for (int i = 1; i < batch_count; ++i) {
1999     uint64_t tmp_batch_stride = raw_ptrs[i] - raw_ptrs[i - 1];
2000     if (tmp_batch_stride != batch_stride) {
2001       needs_allocate_strided = true;
2002       break;
2003     }
2004   }
2005 
2006   size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
2007   size_t matrix_batch_byte_size = matrix_byte_size * batch_count;
2008 
2009   // No need to do re-allocation, take the short cut and return
2010   if (!needs_allocate_strided) {
2011     *device_memory = DeviceMemory<MAPPED_T>(
2012         DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
2013     reallocated = false;
2014     return port::Status::OK();
2015   }
2016 
2017   if (scratch_allocator != nullptr) {
2018     SE_ASSIGN_OR_RETURN(
2019         DeviceMemory<uint8> batch_matrix_bytes,
2020         scratch_allocator->AllocateBytes(matrix_batch_byte_size));
2021     *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes);
2022   } else {
2023     assert(temp_memory != nullptr);
2024     SE_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray<MAPPED_T>(
2025                                           matrix_batch_byte_size));
2026     *device_memory =
2027         DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
2028   }
2029 
2030   reallocated = true;
2031 
2032   if (copy_data)
2033     return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
2034                             batch_stride, true);
2035   return port::Status::OK();
2036 }
2037 
2038 template <typename T, typename FuncT>
DoBlasGemmBatchedInternal(FuncT rocblas_func,Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,T alpha,const port::ArraySlice<DeviceMemory<T> * > & a_ptrs_to_wrappers,int lda,const port::ArraySlice<DeviceMemory<T> * > & b_ptrs_to_wrappers,int ldb,T beta,const port::ArraySlice<DeviceMemory<T> * > & c_ptrs_to_wrappers,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2039 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
2040     FuncT rocblas_func, Stream *stream, blas::Transpose transa,
2041     blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
2042     const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
2043     const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
2044     T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
2045     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2046   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
2047 
2048   // Sanity checks before making any further progress
2049   uint64_t batch_stride_a = 0;
2050   uint64_t batch_stride_b = 0;
2051   uint64_t batch_stride_c = 0;
2052 
2053   assert(ldc >= m);
2054   batch_stride_c = ldc * n;
2055 
2056   if (ROCMBlasTranspose(transa) == rocblas_operation_none) {
2057     assert(lda >= m);
2058     batch_stride_a = lda * k;
2059   } else {
2060     assert(lda >= k);
2061     batch_stride_a = lda * m;
2062   }
2063 
2064   if (ROCMBlasTranspose(transb) == rocblas_operation_none) {
2065     assert(ldb >= k);
2066     batch_stride_b = ldb * n;
2067   } else {
2068     assert(ldb >= n);
2069     batch_stride_b = ldb * k;
2070   }
2071 
2072   // Allocate local vectors to hold device pointers to matrices
2073   std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
2074   for (int i = 0; i < batch_count; ++i) {
2075     // static_cast does work when converting Eigen::half* to rocblas_half*,
2076     // hence the use of reinterpret_cast
2077     a_raw_ptrs.push_back(
2078         reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
2079     b_raw_ptrs.push_back(
2080         reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
2081     c_raw_ptrs.push_back(
2082         reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
2083   }
2084 
2085   DeviceMemory<MAPPED_T> a;
2086   // Make sure the temporary memory are in-scope before the function returns
2087   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
2088   bool reallocated_a, reallocated_b, reallocated_c;
2089   port::Status a_allocation_status = AllocateStridedBuffer<T>(
2090       a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream,
2091       &a_temp, &a, true, reallocated_a);
2092   if (a_allocation_status != port::Status::OK()) {
2093     return a_allocation_status;
2094   }
2095 
2096   DeviceMemory<MAPPED_T> b;
2097   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
2098   port::Status b_allocation_status = AllocateStridedBuffer<T>(
2099       b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream,
2100       &b_temp, &b, true, reallocated_b);
2101   if (b_allocation_status != port::Status::OK()) {
2102     return b_allocation_status;
2103   }
2104 
2105   DeviceMemory<MAPPED_T> c;
2106   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
2107   port::Status c_allocation_status = AllocateStridedBuffer<T>(
2108       c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream,
2109       &c_temp, &c, true, reallocated_c);  // can disable copy if beta=0
2110   if (c_allocation_status != port::Status::OK()) {
2111     return c_allocation_status;
2112   }
2113 
2114   MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
2115   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
2116 
2117   bool ok;
2118   ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true,
2119                       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2120                       n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
2121                       batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
2122                       GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
2123                       batch_stride_c, batch_count);
2124   if (!ok)
2125     return port::Status(port::error::INTERNAL,
2126                         "failed BLAS call, see log for details");
2127   if (reallocated_c)
2128     return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c,
2129                             false);
2130   return port::Status::OK();
2131 }
2132 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2133 bool ROCMBlas::DoBlasGemmBatched(
2134     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2135     uint64 n, uint64 k, float alpha,
2136     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
2137     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
2138     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
2139     int batch_count, ScratchAllocator *scratch_allocator) {
2140   blas_log("DoBlasGemmBatched");
2141   const Eigen::half alpha_half(alpha);
2142   const Eigen::half beta_half(beta);
2143 
2144   port::Status status = DoBlasGemmBatchedInternal(
2145       wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
2146       alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
2147       scratch_allocator);
2148   if (!status.ok()) {
2149     LOG(ERROR) << status;
2150   }
2151 
2152   return status.ok();
2153 }
2154 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<float> * > & b_array,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2155 bool ROCMBlas::DoBlasGemmBatched(
2156     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2157     uint64 n, uint64 k, float alpha,
2158     const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
2159     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
2160     const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
2161     int batch_count, ScratchAllocator *scratch_allocator) {
2162   blas_log("DoBlasGemmBatched");
2163   port::Status status = DoBlasGemmBatchedInternal(
2164       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
2165       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2166       scratch_allocator);
2167   if (!status.ok()) {
2168     LOG(ERROR) << status;
2169   }
2170   return status.ok();
2171 }
2172 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<double> * > & b_array,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2173 bool ROCMBlas::DoBlasGemmBatched(
2174     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2175     uint64 n, uint64 k, double alpha,
2176     const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
2177     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
2178     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
2179     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2180   blas_log("DoBlasGemmBatched");
2181   port::Status status = DoBlasGemmBatchedInternal(
2182       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
2183       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2184       scratch_allocator);
2185   if (!status.ok()) {
2186     LOG(ERROR) << status;
2187   }
2188   return status.ok();
2189 }
2190 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b_array,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2191 bool ROCMBlas::DoBlasGemmBatched(
2192     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2193     uint64 n, uint64 k, std::complex<float> alpha,
2194     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
2195     int lda,
2196     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
2197     int ldb, std::complex<float> beta,
2198     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
2199     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2200   blas_log("DoBlasGemmBatched");
2201   port::Status status = DoBlasGemmBatchedInternal(
2202       wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
2203       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2204       scratch_allocator);
2205   if (!status.ok()) {
2206     LOG(ERROR) << status;
2207   }
2208   return status.ok();
2209 }
2210 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b_array,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)2211 bool ROCMBlas::DoBlasGemmBatched(
2212     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2213     uint64 n, uint64 k, std::complex<double> alpha,
2214     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
2215     int lda,
2216     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
2217     int ldb, std::complex<double> beta,
2218     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
2219     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
2220   blas_log("DoBlasGemmBatched");
2221   port::Status status = DoBlasGemmBatchedInternal(
2222       wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
2223       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
2224       scratch_allocator);
2225   if (!status.ok()) {
2226     LOG(ERROR) << status;
2227   }
2228   return status.ok();
2229 }
2230 
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2231 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
2232                           blas::UpperLower uplo, uint64 m, uint64 n,
2233                           std::complex<float> alpha,
2234                           const DeviceMemory<std::complex<float>> &a, int lda,
2235                           const DeviceMemory<std::complex<float>> &b, int ldb,
2236                           std::complex<float> beta,
2237                           DeviceMemory<std::complex<float>> *c, int ldc) {
2238   return DoBlasInternal(wrap::rocblas_chemm, stream,
2239                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2240                         ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2241                         complex_cast(a), lda, complex_cast(b), ldb,
2242                         complex_cast(beta), complex_cast(c), ldc);
2243 }
2244 
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2245 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
2246                           blas::UpperLower uplo, uint64 m, uint64 n,
2247                           std::complex<double> alpha,
2248                           const DeviceMemory<std::complex<double>> &a, int lda,
2249                           const DeviceMemory<std::complex<double>> &b, int ldb,
2250                           std::complex<double> beta,
2251                           DeviceMemory<std::complex<double>> *c, int ldc) {
2252   return DoBlasInternal(wrap::rocblas_zhemm, stream,
2253                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2254                         ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2255                         complex_cast(a), lda, complex_cast(b), ldb,
2256                         complex_cast(beta), complex_cast(c), ldc);
2257 }
2258 
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2259 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2260                           blas::Transpose trans, uint64 n, uint64 k,
2261                           float alpha,
2262                           const DeviceMemory<std::complex<float>> &a, int lda,
2263                           float beta, DeviceMemory<std::complex<float>> *c,
2264                           int ldc) {
2265   return DoBlasInternal(wrap::rocblas_cherk, stream,
2266                         /* pointer_mode_host = */ true,
2267                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2268                         k, complex_cast(alpha), complex_cast(a), lda,
2269                         complex_cast(beta), complex_cast(c), ldc);
2270 }
2271 
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2272 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2273                           blas::Transpose trans, uint64 n, uint64 k,
2274                           double alpha,
2275                           const DeviceMemory<std::complex<double>> &a, int lda,
2276                           double beta, DeviceMemory<std::complex<double>> *c,
2277                           int ldc) {
2278   return DoBlasInternal(wrap::rocblas_zherk, stream,
2279                         /* pointer_mode_host = */ true,
2280                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2281                         k, complex_cast(alpha), complex_cast(a), lda,
2282                         complex_cast(beta), complex_cast(c), ldc);
2283 }
2284 
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2285 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2286                            blas::Transpose trans, uint64 n, uint64 k,
2287                            std::complex<float> alpha,
2288                            const DeviceMemory<std::complex<float>> &a, int lda,
2289                            const DeviceMemory<std::complex<float>> &b, int ldb,
2290                            float beta, DeviceMemory<std::complex<float>> *c,
2291                            int ldc) {
2292   return DoBlasInternal(
2293       wrap::rocblas_cher2k, stream, /* pointer_mode_host = */ true,
2294       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2295       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2296       complex_cast(beta), complex_cast(c), ldc);
2297 }
2298 
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2299 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2300                            blas::Transpose trans, uint64 n, uint64 k,
2301                            std::complex<double> alpha,
2302                            const DeviceMemory<std::complex<double>> &a, int lda,
2303                            const DeviceMemory<std::complex<double>> &b, int ldb,
2304                            double beta, DeviceMemory<std::complex<double>> *c,
2305                            int ldc) {
2306   return DoBlasInternal(
2307       wrap::rocblas_zher2k, stream, /* pointer_mode_host = */ true,
2308       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2309       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2310       complex_cast(beta), complex_cast(c), ldc);
2311 }
2312 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2313 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2314                           blas::UpperLower uplo, uint64 m, uint64 n,
2315                           float alpha, const DeviceMemory<float> &a, int lda,
2316                           const DeviceMemory<float> &b, int ldb, float beta,
2317                           DeviceMemory<float> *c, int ldc) {
2318   return DoBlasInternal(
2319       wrap::rocblas_ssymm, stream, /* pointer_mode_host = */ true,
2320       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a),
2321       lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2322 }
2323 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2324 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2325                           blas::UpperLower uplo, uint64 m, uint64 n,
2326                           double alpha, const DeviceMemory<double> &a, int lda,
2327                           const DeviceMemory<double> &b, int ldb, double beta,
2328                           DeviceMemory<double> *c, int ldc) {
2329   return DoBlasInternal(
2330       wrap::rocblas_dsymm, stream, /* pointer_mode_host = */ true,
2331       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, GpuMemory(a),
2332       lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2333 }
2334 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2335 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2336                           blas::UpperLower uplo, uint64 m, uint64 n,
2337                           std::complex<float> alpha,
2338                           const DeviceMemory<std::complex<float>> &a, int lda,
2339                           const DeviceMemory<std::complex<float>> &b, int ldb,
2340                           std::complex<float> beta,
2341                           DeviceMemory<std::complex<float>> *c, int ldc) {
2342   return DoBlasInternal(wrap::rocblas_csymm, stream,
2343                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2344                         ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2345                         complex_cast(a), lda, complex_cast(b), ldb,
2346                         complex_cast(beta), complex_cast(c), ldc);
2347 }
2348 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2349 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2350                           blas::UpperLower uplo, uint64 m, uint64 n,
2351                           std::complex<double> alpha,
2352                           const DeviceMemory<std::complex<double>> &a, int lda,
2353                           const DeviceMemory<std::complex<double>> &b, int ldb,
2354                           std::complex<double> beta,
2355                           DeviceMemory<std::complex<double>> *c, int ldc) {
2356   return DoBlasInternal(wrap::rocblas_zsymm, stream,
2357                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2358                         ROCMBlasUpperLower(uplo), m, n, complex_cast(alpha),
2359                         complex_cast(a), lda, complex_cast(b), ldb,
2360                         complex_cast(beta), complex_cast(c), ldc);
2361 }
2362 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)2363 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2364                           blas::Transpose trans, uint64 n, uint64 k,
2365                           float alpha, const DeviceMemory<float> &a, int lda,
2366                           float beta, DeviceMemory<float> *c, int ldc) {
2367   return DoBlasInternal(
2368       wrap::rocblas_ssyrk, stream, /* pointer_mode_host = */ true,
2369       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2370       GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc);
2371 }
2372 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)2373 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2374                           blas::Transpose trans, uint64 n, uint64 k,
2375                           double alpha, const DeviceMemory<double> &a, int lda,
2376                           double beta, DeviceMemory<double> *c, int ldc) {
2377   return DoBlasInternal(
2378       wrap::rocblas_dsyrk, stream, /* pointer_mode_host = */ true,
2379       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2380       GpuMemory(a), lda, &beta, GpuMemoryMutable(c), ldc);
2381 }
2382 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2383 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2384                           blas::Transpose trans, uint64 n, uint64 k,
2385                           std::complex<float> alpha,
2386                           const DeviceMemory<std::complex<float>> &a, int lda,
2387                           std::complex<float> beta,
2388                           DeviceMemory<std::complex<float>> *c, int ldc) {
2389   return DoBlasInternal(wrap::rocblas_csyrk, stream,
2390                         /* pointer_mode_host = */ true,
2391                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2392                         k, complex_cast(alpha), complex_cast(a), lda,
2393                         complex_cast(beta), complex_cast(c), ldc);
2394 }
2395 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2396 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2397                           blas::Transpose trans, uint64 n, uint64 k,
2398                           std::complex<double> alpha,
2399                           const DeviceMemory<std::complex<double>> &a, int lda,
2400                           std::complex<double> beta,
2401                           DeviceMemory<std::complex<double>> *c, int ldc) {
2402   return DoBlasInternal(wrap::rocblas_zsyrk, stream,
2403                         /* pointer_mode_host = */ true,
2404                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n,
2405                         k, complex_cast(alpha), complex_cast(a), lda,
2406                         complex_cast(beta), complex_cast(c), ldc);
2407 }
2408 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2409 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2410                            blas::Transpose trans, uint64 n, uint64 k,
2411                            float alpha, const DeviceMemory<float> &a, int lda,
2412                            const DeviceMemory<float> &b, int ldb, float beta,
2413                            DeviceMemory<float> *c, int ldc) {
2414   return DoBlasInternal(
2415       wrap::rocblas_ssyr2k, stream, /* pointer_mode_host = */ true,
2416       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2417       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2418 }
2419 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2420 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2421                            blas::Transpose trans, uint64 n, uint64 k,
2422                            double alpha, const DeviceMemory<double> &a, int lda,
2423                            const DeviceMemory<double> &b, int ldb, double beta,
2424                            DeviceMemory<double> *c, int ldc) {
2425   return DoBlasInternal(
2426       wrap::rocblas_dsyr2k, stream, /* pointer_mode_host = */ true,
2427       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha,
2428       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
2429 }
2430 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2431 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2432                            blas::Transpose trans, uint64 n, uint64 k,
2433                            std::complex<float> alpha,
2434                            const DeviceMemory<std::complex<float>> &a, int lda,
2435                            const DeviceMemory<std::complex<float>> &b, int ldb,
2436                            std::complex<float> beta,
2437                            DeviceMemory<std::complex<float>> *c, int ldc) {
2438   return DoBlasInternal(
2439       wrap::rocblas_csyr2k, stream, /* pointer_mode_host = */ true,
2440       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2441       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2442       complex_cast(beta), complex_cast(c), ldc);
2443 }
2444 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2445 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2446                            blas::Transpose trans, uint64 n, uint64 k,
2447                            std::complex<double> alpha,
2448                            const DeviceMemory<std::complex<double>> &a, int lda,
2449                            const DeviceMemory<std::complex<double>> &b, int ldb,
2450                            std::complex<double> beta,
2451                            DeviceMemory<std::complex<double>> *c, int ldc) {
2452   return DoBlasInternal(
2453       wrap::rocblas_zsyr2k, stream, /* pointer_mode_host = */ true,
2454       ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k,
2455       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
2456       complex_cast(beta), complex_cast(c), ldc);
2457 }
2458 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2459 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2460                           blas::UpperLower uplo, blas::Transpose transa,
2461                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2462                           const DeviceMemory<float> &a, int lda,
2463                           DeviceMemory<float> *b, int ldb) {
2464   return DoBlasInternal(wrap::rocblas_strmm, stream,
2465                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2466                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2467                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2468                         GpuMemoryMutable(b), ldb);
2469 }
2470 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2471 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2472                           blas::UpperLower uplo, blas::Transpose transa,
2473                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2474                           const DeviceMemory<double> &a, int lda,
2475                           DeviceMemory<double> *b, int ldb) {
2476   return DoBlasInternal(wrap::rocblas_dtrmm, stream,
2477                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2478                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2479                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2480                         GpuMemoryMutable(b), ldb);
2481 }
2482 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2483 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2484                           blas::UpperLower uplo, blas::Transpose transa,
2485                           blas::Diagonal diag, uint64 m, uint64 n,
2486                           std::complex<float> alpha,
2487                           const DeviceMemory<std::complex<float>> &a, int lda,
2488                           DeviceMemory<std::complex<float>> *b, int ldb) {
2489   return DoBlasInternal(wrap::rocblas_ctrmm, stream,
2490                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2491                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2492                         ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2493                         complex_cast(a), lda, complex_cast(b), ldb);
2494 }
2495 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2496 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2497                           blas::UpperLower uplo, blas::Transpose transa,
2498                           blas::Diagonal diag, uint64 m, uint64 n,
2499                           std::complex<double> alpha,
2500                           const DeviceMemory<std::complex<double>> &a, int lda,
2501                           DeviceMemory<std::complex<double>> *b, int ldb) {
2502   return DoBlasInternal(wrap::rocblas_ztrmm, stream,
2503                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2504                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2505                         ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2506                         complex_cast(a), lda, complex_cast(b), ldb);
2507 }
2508 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2509 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2510                           blas::UpperLower uplo, blas::Transpose transa,
2511                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2512                           const DeviceMemory<float> &a, int lda,
2513                           DeviceMemory<float> *b, int ldb) {
2514   blas_log("DoBlasTrsm");
2515   return DoBlasInternal(wrap::rocblas_strsm, stream,
2516                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2517                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2518                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2519                         GpuMemoryMutable(b), ldb);
2520 }
2521 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2522 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2523                           blas::UpperLower uplo, blas::Transpose transa,
2524                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2525                           const DeviceMemory<double> &a, int lda,
2526                           DeviceMemory<double> *b, int ldb) {
2527   blas_log("DoBlasTrsm");
2528   return DoBlasInternal(wrap::rocblas_dtrsm, stream,
2529                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2530                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2531                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
2532                         GpuMemoryMutable(b), ldb);
2533 }
2534 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2535 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2536                           blas::UpperLower uplo, blas::Transpose transa,
2537                           blas::Diagonal diag, uint64 m, uint64 n,
2538                           std::complex<float> alpha,
2539                           const DeviceMemory<std::complex<float>> &a, int lda,
2540                           DeviceMemory<std::complex<float>> *b, int ldb) {
2541   return DoBlasInternal(wrap::rocblas_ctrsm, stream,
2542                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2543                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2544                         ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2545                         complex_cast(a), lda, complex_cast(b), ldb);
2546 }
2547 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2548 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2549                           blas::UpperLower uplo, blas::Transpose transa,
2550                           blas::Diagonal diag, uint64 m, uint64 n,
2551                           std::complex<double> alpha,
2552                           const DeviceMemory<std::complex<double>> &a, int lda,
2553                           DeviceMemory<std::complex<double>> *b, int ldb) {
2554   return DoBlasInternal(wrap::rocblas_ztrsm, stream,
2555                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
2556                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2557                         ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
2558                         complex_cast(a), lda, complex_cast(b), ldb);
2559 }
2560 
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)2561 bool ROCMBlas::DoBlasGemmStridedBatched(
2562     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2563     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
2564     int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
2565     int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
2566     int64 stride_c, int batch_count) {
2567   blas_log("DoBlasGemmStridedBatched");
2568   const Eigen::half alpha_half(alpha);
2569   const Eigen::half beta_half(beta);
2570 
2571   return DoBlasInternal(
2572       wrap::rocblas_hgemm_strided_batched, stream,
2573       false, /* pointer_mode_host */
2574       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
2575       reinterpret_cast<const rocblas_half *>(&alpha_half),
2576       reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda, stride_a,
2577       reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb, stride_b,
2578       reinterpret_cast<const rocblas_half *>(&beta_half),
2579       reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc, stride_c,
2580       batch_count);
2581 }
2582 
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)2583 bool ROCMBlas::DoBlasGemmStridedBatched(
2584     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2585     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
2586     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
2587     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
2588     int batch_count) {
2589   VLOG(1) << absl::StreamFormat(
2590       "doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
2591       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
2592       "c=%p ldc=%d",
2593       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
2594       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
2595   return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
2596                         false, /* pointer_mode_host */
2597                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2598                         n, k, &alpha, GpuMemory(a), lda, stride_a, GpuMemory(b),
2599                         ldb, stride_b, &beta, GpuMemoryMutable(c), ldc,
2600                         stride_c, batch_count);
2601 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)2602 bool ROCMBlas::DoBlasGemmStridedBatched(
2603     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2604     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
2605     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
2606     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
2607     int batch_count) {
2608   VLOG(1) << absl::StreamFormat(
2609       "doing rocBLAS SGEMM Strided Batched<double>: at=%d bt=%d m=%u n=%u "
2610       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
2611       "c=%p ldc=%d",
2612       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
2613       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
2614   return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
2615                         false, /* pointer_mode_host */
2616                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2617                         n, k, &alpha, GpuMemory(a), lda, stride_a, GpuMemory(b),
2618                         ldb, stride_b, &beta, GpuMemoryMutable(c), ldc,
2619                         stride_c, batch_count);
2620 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)2621 bool ROCMBlas::DoBlasGemmStridedBatched(
2622     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2623     uint64 n, uint64 k, std::complex<float> alpha,
2624     const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
2625     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
2626     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
2627     int64 stride_c, int batch_count) {
2628   return DoBlasInternal(wrap::rocblas_cgemm_strided_batched, stream,
2629                         false, /* pointer_mode_host */
2630                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2631                         n, k, complex_cast(alpha), complex_cast(a), lda,
2632                         stride_a, complex_cast(b), ldb, stride_b,
2633                         complex_cast(beta), complex_cast(c), ldc, stride_c,
2634                         batch_count);
2635 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)2636 bool ROCMBlas::DoBlasGemmStridedBatched(
2637     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2638     uint64 n, uint64 k, std::complex<double> alpha,
2639     const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
2640     const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
2641     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
2642     int64 stride_c, int batch_count) {
2643   return DoBlasInternal(wrap::rocblas_zgemm_strided_batched, stream,
2644                         false, /* pointer_mode_host */
2645                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
2646                         n, k, complex_cast(alpha), complex_cast(a), lda,
2647                         stride_a, complex_cast(b), ldb, stride_b,
2648                         complex_cast(beta), complex_cast(c), ldc, stride_c,
2649                         batch_count);
2650 }
2651 
GetVersion(string * version)2652 port::Status ROCMBlas::GetVersion(string *version) {
2653   return port::UnimplementedError("");
2654 }
2655 
2656 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams & p)2657 ROCMBlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
2658   return port::Status(
2659       port::error::UNIMPLEMENTED,
2660       "CreateBlasLtMatmulPlan is not supported with this version of ROCM");
2661 }
2662 
2663 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan * plan,size_t max_workspace_size,int max_algorithm_count)2664 ROCMBlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
2665                                     size_t max_workspace_size,
2666                                     int max_algorithm_count) {
2667   return port::Status(
2668       port::error::UNIMPLEMENTED,
2669       "GetBlasLtMatmulAlgorithms is not supported with this version of ROCM");
2670 }
2671 
DoBlasLtMatmul(Stream * stream,const blas::IBlasLtMatmulPlan * plan,const HostOrDeviceScalar<void> & alpha,DeviceMemoryBase a,DeviceMemoryBase b,const HostOrDeviceScalar<void> & beta,DeviceMemoryBase c,ScratchAllocator * scratch_allocator,const blas::IBlasLtMatmulAlgorithm * algorithm,DeviceMemoryBase bias,blas::ProfileResult * output_profile_result)2672 bool ROCMBlas::DoBlasLtMatmul(
2673     Stream *stream, const blas::IBlasLtMatmulPlan *plan,
2674     const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
2675     DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
2676     DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
2677     const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
2678     blas::ProfileResult *output_profile_result) {
2679   return false;
2680 }
2681 
2682 }  // namespace gpu
2683 
initialize_rocblas()2684 void initialize_rocblas() {
2685   auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
2686       rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2687 
2688   if (!rocBlasAlreadyRegistered) {
2689     port::Status status =
2690         PluginRegistry::Instance()
2691             ->RegisterFactory<PluginRegistry::BlasFactory>(
2692                 rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
2693                 [](internal::StreamExecutorInterface *parent)
2694                     -> blas::BlasSupport * {
2695                   gpu::GpuExecutor *rocm_executor =
2696                       dynamic_cast<gpu::GpuExecutor *>(parent);
2697                   if (rocm_executor == nullptr) {
2698                     LOG(ERROR)
2699                         << "Attempting to initialize an instance of the "
2700                            "rocBLAS "
2701                         << "support library with a non-ROCM StreamExecutor";
2702                     return nullptr;
2703                   }
2704 
2705                   gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
2706                   if (!blas->Init()) {
2707                     // Note: Init() will log a more specific error.
2708                     delete blas;
2709                     return nullptr;
2710                   }
2711                   return blas;
2712                 });
2713 
2714     if (!status.ok()) {
2715       LOG(ERROR) << "Unable to register rocBLAS factory: "
2716                  << status.error_message();
2717     }
2718 
2719     PluginRegistry::Instance()->SetDefaultFactory(
2720         rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2721   }
2722 }
2723 
2724 }  // namespace stream_executor
2725 
2726 REGISTER_MODULE_INITIALIZER(register_rocblas,
2727                             { stream_executor::initialize_rocblas(); });
2728