1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "rocm/include/rocblas.h"
17
18 #include "tensorflow/stream_executor/rocm/rocm_blas.h"
19
20 #define EIGEN_USE_GPU
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22
23 #include <assert.h>
24 #include <complex>
25
26 #include "absl/strings/str_cat.h"
27 #include "tensorflow/stream_executor/device_memory.h"
28 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
29 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
30 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
31 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
32 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
33 #include "tensorflow/stream_executor/lib/env.h"
34 #include "tensorflow/stream_executor/lib/initialize.h"
35 #include "tensorflow/stream_executor/lib/status.h"
36 #include "tensorflow/stream_executor/lib/status_macros.h"
37 #include "tensorflow/stream_executor/lib/stringprintf.h"
38 #include "tensorflow/stream_executor/platform/dso_loader.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream_executor.h"
45
46 namespace stream_executor {
47 namespace gpu {
48
49 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
50
51 namespace wrap {
52
53 #ifdef PLATFORM_GOOGLE
54 #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
55 struct WrapperShim__##__name { \
56 static const char *kName; \
57 template <typename... Args> \
58 rocblas_status operator()(GpuExecutor *parent, Args... args) { \
59 gpu::ScopedActivateExecutorContext sac{parent}; \
60 return ::__name(args...); \
61 } \
62 } __name; \
63 const char *WrapperShim__##__name::kName = #__name;
64
65 #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
66 STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
67
68 #else
69
70 #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
71 struct DynLoadShim__##__name { \
72 static const char *kName; \
73 using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
74 static void *GetDsoHandle() { \
75 auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
76 return s.ValueOrDie(); \
77 } \
78 static FuncPtrT LoadOrDie() { \
79 void *f; \
80 auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
81 kName, &f); \
82 CHECK(s.ok()) << "could not find " << kName \
83 << " in rocblas DSO; dlerror: " << s.error_message(); \
84 return reinterpret_cast<FuncPtrT>(f); \
85 } \
86 static FuncPtrT DynLoad() { \
87 static FuncPtrT f = LoadOrDie(); \
88 return f; \
89 } \
90 template <typename... Args> \
91 rocblas_status operator()(GpuExecutor *parent, Args... args) { \
92 gpu::ScopedActivateExecutorContext sac{parent}; \
93 return DynLoad()(args...); \
94 } \
95 } __name; \
96 const char *DynLoadShim__##__name::kName = #__name;
97
98 #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
99 STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
100
101 #endif
102
103 #define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \
104 __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /* __macro(rocblas_scnrm2) \
105 __macro(rocblas_dznrm2) */ \
106 __macro(rocblas_sdot) \
107 __macro(rocblas_ddot) /* __macro(rocblas_cdotu) \
108 __macro(rocblas_cdotc) \
109 __macro(rocblas_zdotu) \
110 __macro(rocblas_zdotc) */ \
111 __macro(rocblas_sscal) \
112 __macro(rocblas_dscal) /* __macro(rocblas_cscal) \
113 __macro(rocblas_csscal) \
114 __macro(rocblas_zscal) \
115 __macro(rocblas_zdscal) */ \
116 __macro(rocblas_saxpy) \
117 __macro(rocblas_daxpy) /* __macro(rocblas_caxpy) \
118 __macro(rocblas_zaxpy) */ \
119 __macro(rocblas_scopy) \
120 __macro(rocblas_dcopy) /* __macro(rocblas_ccopy) \
121 __macro(rocblas_zcopy) */ \
122 __macro(rocblas_sswap) \
123 __macro(rocblas_dswap) /* __macro(rocblas_cswap) \
124 __macro(rocblas_zswap) */ \
125 __macro(rocblas_isamax) \
126 __macro(rocblas_idamax) /* __macro(rocblas_icamax) \
127 __macro(rocblas_izamax) */ \
128 __macro(rocblas_isamin) \
129 __macro(rocblas_idamin) /* __macro(rocblas_icamin) \
130 __macro(rocblas_izamin) */ \
131 __macro(rocblas_sasum) \
132 __macro(rocblas_dasum) /* __macro(rocblas_scasum) \
133 __macro(rocblas_dzasum) \
134 __macro(rocblas_srot) \
135 __macro(rocblas_drot) \
136 __macro(rocblas_crot) \
137 __macro(rocblas_csrot) \
138 __macro(rocblas_zrot) \
139 __macro(rocblas_zdrot) \
140 __macro(rocblas_srotg) \
141 __macro(rocblas_drotg) \
142 __macro(rocblas_Crotg) \
143 __macro(rocblas_crotg) \
144 __macro(rocblas_zrotm) \
145 __macro(rocblas_drotm) \
146 __macro(rocblas_srotmg) \
147 __macro(rocblas_drotmg) */ \
148 __macro(rocblas_sgemv) \
149 __macro(rocblas_dgemv) /* __macro(rocblas_cgemv) \
150 __macro(rocblas_zgemv) \
151 __macro(rocblas_sgbmv) \
152 __macro(rocblas_dgbmv) \
153 __macro(rocblas_cgbmv) \
154 __macro(rocblas_zgbmv) \
155 __macro(rocblas_strmv) \
156 __macro(rocblas_dtrmv) \
157 __macro(rocblas_ctrmv) \
158 __macro(rocblas_ztrmv) \
159 __macro(rocblas_stbmv) \
160 __macro(rocblas_dtbmv) \
161 __macro(rocblas_ctbmv) \
162 __macro(rocblas_ztbmv) \
163 __macro(rocblas_stpmv) \
164 __macro(rocblas_dtpmv) \
165 __macro(rocblas_ctpmv) \
166 __macro(rocblas_ztpmv) \
167 __macro(rocblas_strsv) \
168 __macro(rocblas_dtrsv) \
169 __macro(rocblas_ctrsv) \
170 __macro(rocblas_ztrsv) \
171 __macro(rocblas_stpsv) \
172 __macro(rocblas_dtpsv) \
173 __macro(rocblas_ctpsv) \
174 __macro(rocblas_ztpsv) \
175 __macro(rocblas_stbsv) \
176 __macro(rocblas_dtbsv) \
177 __macro(rocblas_ctbsv) \
178 __macro(rocblas_ztbsv) \
179 __macro(rocblas_ssymv) \
180 __macro(rocblas_dsymv) \
181 __macro(rocblas_csymv) \
182 __macro(rocblas_zsymv) \
183 __macro(rocblas_chemv) \
184 __macro(rocblas_zhemv) \
185 __macro(rocblas_ssbmv) \
186 __macro(rocblas_dsbmv) \
187 __macro(rocblas_chbmv) \
188 __macro(rocblas_zhbmv) \
189 __macro(rocblas_sspmv) \
190 __macro(rocblas_dspmv) \
191 __macro(rocblas_chpmv) \
192 __macro(rocblas_zhpmv) */ \
193 __macro(rocblas_sger) \
194 __macro(rocblas_dger) /* __macro(rocblas_cgeru) \
195 __macro(rocblas_cgerc) \
196 __macro(rocblas_zgeru) \
197 __macro(rocblas_zgerc) */ \
198 __macro(rocblas_ssyr) \
199 __macro(rocblas_dsyr) /* __macro(rocblas_csyr) \
200 __macro(rocblas_zsyr) \
201 __macro(rocblas_cher) \
202 __macro(rocblas_zher) \
203 __macro(rocblas_sspr) \
204 __macro(rocblas_dspr) \
205 __macro(rocblas_chpr) \
206 __macro(rocblas_zhpr) \
207 __macro(rocblas_ssyr2) \
208 __macro(rocblas_dsyr2) \
209 __macro(rocblas_csyr2) \
210 __macro(rocblas_zsyr2) \
211 __macro(rocblas_cher2) \
212 __macro(rocblas_zher2) \
213 __macro(rocblas_sspr2) \
214 __macro(rocblas_dspr2) \
215 __macro(rocblas_chpr2) \
216 __macro(rocblas_zhpr2) */ \
217 __macro(rocblas_sgemm) __macro(rocblas_dgemm) \
218 __macro(rocblas_hgemm) /* __macro(rocblas_cgemm) \
219 __macro(rocblas_zgemm) \
220 __macro(rocblas_ssyrk) \
221 __macro(rocblas_dsyrk) \
222 __macro(rocblas_csyrk) \
223 __macro(rocblas_zsyrk) \
224 __macro(rocblas_cherk) \
225 __macro(rocblas_zherk) \
226 __macro(rocblas_ssyr2k) \
227 __macro(rocblas_dsyr2k) \
228 __macro(rocblas_csyr2k) \
229 __macro(rocblas_zsyr2k) \
230 __macro(rocblas_cher2k) \
231 __macro(rocblas_zher2k) \
232 __macro(rocblas_ssyrkx) \
233 __macro(rocblas_dsyrkx) \
234 __macro(rocblas_csyrkx) \
235 __macro(rocblas_zsyrkx) \
236 __macro(rocblas_cherkx) \
237 __macro(rocblas_zherkx) \
238 __macro(rocblas_ssymm) \
239 __macro(rocblas_dsymm) \
240 __macro(rocblas_csymm) \
241 __macro(rocblas_zsymm) \
242 __macro(rocblas_chemm) \
243 __macro(rocblas_zhemm) */ \
244 __macro(rocblas_strsm) \
245 __macro(rocblas_dtrsm) /* __macro(rocblas_ctrsm) \
246 __macro(rocblas_ztrsm) \
247 __macro(rocblas_strmm) \
248 __macro(rocblas_dtrmm) \
249 __macro(rocblas_ctrmm) \
250 __macro(rocblas_ztrmm) */ \
251 __macro(rocblas_sgeam) \
252 __macro(rocblas_dgeam) /* __macro(rocblas_cgeam) \
253 __macro(rocblas_zgeam) \
254 __macro(rocblas_sdgmm) \
255 __macro(rocblas_ddgmm) \
256 __macro(rocblas_cdgmm) \
257 __macro(rocblas_zdgmm) */
258
259 STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle)
260 STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle)
261 STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream)
262 // STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode)
263 // STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode)
264 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched)
265 STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
266 STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
267 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
268 STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
269 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
270 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
271 ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)
272
273 } // namespace wrap
274
ToString(rocblas_status status)275 static string ToString(rocblas_status status) {
276 switch (status) {
277 case rocblas_status_success:
278 return "rocblas_status_success";
279 case rocblas_status_invalid_handle:
280 return "rocblas_status_invalid_handle";
281 case rocblas_status_not_implemented:
282 return "rocblas_status_not_implemented";
283 case rocblas_status_invalid_pointer:
284 return "rocblas_status_invalid_pointer";
285 case rocblas_status_invalid_size:
286 return "rocblas_status_invalid_size";
287 case rocblas_status_memory_error:
288 return "rocblas_status_memory_error";
289 case rocblas_status_internal_error:
290 return "rocblas_status_internal_error";
291 default:
292 return absl::StrCat("<invalid rocBLAS status: ", status, ">");
293 }
294 }
295
Init()296 bool ROCMBlas::Init() {
297 rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_);
298 if (ret != rocblas_status_success) {
299 LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
300 return false;
301 }
302
303 return true;
304 }
305
ROCMBlas(gpu::GpuExecutor * parent)306 ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
307 : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
308
~ROCMBlas()309 ROCMBlas::~ROCMBlas() {
310 if (blas_ != nullptr) {
311 wrap::rocblas_destroy_handle(parent_, blas_);
312 }
313 }
314
SetStream(Stream * stream)315 bool ROCMBlas::SetStream(Stream *stream) {
316 CHECK(stream != nullptr);
317 CHECK(AsGpuStreamValue(stream) != nullptr);
318 CHECK(blas_ != nullptr);
319 rocblas_status ret =
320 wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream));
321 if (ret != rocblas_status_success) {
322 LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
323 return false;
324 }
325
326 return true;
327 }
328
329 namespace {
330
331 // Helper functions transforming blas arguments into rocBLAS arguments.
332
ROCMBlasTranspose(blas::Transpose trans)333 rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
334 switch (trans) {
335 case blas::Transpose::kNoTranspose:
336 return rocblas_operation_none;
337 case blas::Transpose::kTranspose:
338 return rocblas_operation_transpose;
339 case blas::Transpose::kConjugateTranspose:
340 return rocblas_operation_conjugate_transpose;
341 default:
342 LOG(FATAL) << "Invalid value of blas::Transpose.";
343 }
344 }
345
ROCMBlasUpperLower(blas::UpperLower uplo)346 rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
347 switch (uplo) {
348 case blas::UpperLower::kUpper:
349 return rocblas_fill_upper;
350 case blas::UpperLower::kLower:
351 return rocblas_fill_lower;
352 default:
353 LOG(FATAL) << "Invalid value of blas::UpperLower.";
354 }
355 }
356
ROCMBlasDiagonal(blas::Diagonal diag)357 rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
358 switch (diag) {
359 case blas::Diagonal::kUnit:
360 return rocblas_diagonal_unit;
361 case blas::Diagonal::kNonUnit:
362 return rocblas_diagonal_non_unit;
363 default:
364 LOG(FATAL) << "Invalid value of blas::Diagonal.";
365 }
366 }
367
ROCMBlasSide(blas::Side side)368 rocblas_side ROCMBlasSide(blas::Side side) {
369 switch (side) {
370 case blas::Side::kLeft:
371 return rocblas_side_left;
372 case blas::Side::kRight:
373 return rocblas_side_right;
374 default:
375 LOG(FATAL) << "Invalid value of blas::Side.";
376 }
377 }
378
379 } // namespace
380
381 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,Args...args)382 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
383 bool pointer_mode_host, bool err_on_failure,
384 Args... args) {
385 mutex_lock lock{mu_};
386
387 CHECK(blas_ != nullptr);
388 if (!SetStream(stream)) {
389 return false;
390 }
391
392 rocblas_status ret = rocblas_func(parent_, blas_, args...);
393 if (err_on_failure && ret != rocblas_status_success) {
394 LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
395 << ToString(ret);
396 }
397 return ret == rocblas_status_success;
398 }
399
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)400 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
401 const DeviceMemory<float> &x, int incx,
402 DeviceMemory<float> *result) {
403 return DoBlasInternal(wrap::rocblas_sasum, stream,
404 false /* = pointer_mode_host */, elem_count,
405 GpuMemory(x), incx, GpuMemoryMutable(result));
406 }
407
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)408 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
409 const DeviceMemory<double> &x, int incx,
410 DeviceMemory<double> *result) {
411 return DoBlasInternal(wrap::rocblas_dasum, stream,
412 false /* = pointer_mode_host */, elem_count,
413 GpuMemory(x), incx, GpuMemoryMutable(result));
414 }
415
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)416 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
417 const DeviceMemory<std::complex<float>> &x, int incx,
418 DeviceMemory<float> *result) {
419 LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
420 << "for the \"complex<float>\" dataype";
421 return false;
422 }
423
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)424 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
425 const DeviceMemory<std::complex<double>> &x, int incx,
426 DeviceMemory<double> *result) {
427 LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
428 << "for the \"complex<double>\" dataype";
429 return false;
430 }
431
DoBlasAxpy(Stream * stream,uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)432 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
433 const DeviceMemory<float> &x, int incx,
434 DeviceMemory<float> *y, int incy) {
435 return DoBlasInternal(wrap::rocblas_saxpy, stream,
436 true /* = pointer_mode_host */, elem_count, &alpha,
437 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
438 }
439
DoBlasAxpy(Stream * stream,uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)440 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
441 const DeviceMemory<double> &x, int incx,
442 DeviceMemory<double> *y, int incy) {
443 return DoBlasInternal(wrap::rocblas_daxpy, stream,
444 true /* = pointer_mode_host */, elem_count, &alpha,
445 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
446 }
447
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)448 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
449 std::complex<float> alpha,
450 const DeviceMemory<std::complex<float>> &x, int incx,
451 DeviceMemory<std::complex<float>> *y, int incy) {
452 LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
453 << "for the \"complex<float>\" dataype";
454 return false;
455 }
456
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)457 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
458 std::complex<double> alpha,
459 const DeviceMemory<std::complex<double>> &x, int incx,
460 DeviceMemory<std::complex<double>> *y, int incy) {
461 LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
462 << "for the \"complex<double>\" dataype";
463 return false;
464 }
465
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)466 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
467 const DeviceMemory<float> &x, int incx,
468 DeviceMemory<float> *y, int incy) {
469 return DoBlasInternal(wrap::rocblas_scopy, stream,
470 true /* = pointer_mode_host */, elem_count,
471 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
472 }
473
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)474 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
475 const DeviceMemory<double> &x, int incx,
476 DeviceMemory<double> *y, int incy) {
477 return DoBlasInternal(wrap::rocblas_dcopy, stream,
478 true /* = pointer_mode_host */, elem_count,
479 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
480 }
481
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)482 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
483 const DeviceMemory<std::complex<float>> &x, int incx,
484 DeviceMemory<std::complex<float>> *y, int incy) {
485 LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
486 << "for the \"complex<float>\" dataype";
487 return false;
488 }
489
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)490 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
491 const DeviceMemory<std::complex<double>> &x, int incx,
492 DeviceMemory<std::complex<double>> *y, int incy) {
493 LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
494 << "for the \"complex<double>\" dataype";
495 return false;
496 }
497
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)498 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
499 const DeviceMemory<float> &x, int incx,
500 const DeviceMemory<float> &y, int incy,
501 DeviceMemory<float> *result) {
502 return DoBlasInternal(
503 wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count,
504 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
505 }
506
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)507 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
508 const DeviceMemory<double> &x, int incx,
509 const DeviceMemory<double> &y, int incy,
510 DeviceMemory<double> *result) {
511 return DoBlasInternal(
512 wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count,
513 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
514 }
515
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)516 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
517 const DeviceMemory<std::complex<float>> &x, int incx,
518 const DeviceMemory<std::complex<float>> &y, int incy,
519 DeviceMemory<std::complex<float>> *result) {
520 LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
521 << "for the \"complex<float>\" dataype";
522 return false;
523 }
524
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)525 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
526 const DeviceMemory<std::complex<double>> &x, int incx,
527 const DeviceMemory<std::complex<double>> &y, int incy,
528 DeviceMemory<std::complex<double>> *result) {
529 LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
530 << "for the \"complex<double>\" dataype";
531 return false;
532 }
533
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)534 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
535 const DeviceMemory<std::complex<float>> &x, int incx,
536 const DeviceMemory<std::complex<float>> &y, int incy,
537 DeviceMemory<std::complex<float>> *result) {
538 LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
539 << "for the \"complex<float>\" dataype";
540 return false;
541 }
542
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)543 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
544 const DeviceMemory<std::complex<double>> &x, int incx,
545 const DeviceMemory<std::complex<double>> &y, int incy,
546 DeviceMemory<std::complex<double>> *result) {
547 LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
548 << "for the \"complex<double>\" dataype";
549 return false;
550 }
551
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)552 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
553 const DeviceMemory<float> &x, int incx,
554 DeviceMemory<float> *result) {
555 return DoBlasInternal(wrap::rocblas_snrm2, stream,
556 false /* = pointer_mode_host */, elem_count,
557 GpuMemory(x), incx, GpuMemoryMutable(result));
558 }
559
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)560 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
561 const DeviceMemory<double> &x, int incx,
562 DeviceMemory<double> *result) {
563 return DoBlasInternal(wrap::rocblas_dnrm2, stream,
564 false /* = pointer_mode_host */, elem_count,
565 GpuMemory(x), incx, GpuMemoryMutable(result));
566 }
567
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)568 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
569 const DeviceMemory<std::complex<float>> &x, int incx,
570 DeviceMemory<float> *result) {
571 LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
572 << "for the \"complex<float>\" dataype";
573 return false;
574 }
575
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)576 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
577 const DeviceMemory<std::complex<double>> &x, int incx,
578 DeviceMemory<double> *result) {
579 LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
580 << "for the \"complex<double>\" dataype";
581 return false;
582 }
583
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)584 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
585 DeviceMemory<float> *x, int incx,
586 DeviceMemory<float> *y, int incy, float c, float s) {
587 LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
588 << "for the \"float\" dataype";
589 return false;
590 }
591
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)592 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
593 DeviceMemory<double> *x, int incx,
594 DeviceMemory<double> *y, int incy, double c,
595 double s) {
596 LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
597 << "for the \"double\" dataype";
598 return false;
599 }
600
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)601 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
602 DeviceMemory<std::complex<float>> *x, int incx,
603 DeviceMemory<std::complex<float>> *y, int incy,
604 float c, float s) {
605 LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
606 << "for the \"complex<float>\" dataype";
607 return false;
608 }
609
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)610 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
611 DeviceMemory<std::complex<double>> *x, int incx,
612 DeviceMemory<std::complex<double>> *y, int incy,
613 double c, double s) {
614 LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
615 << "for the \"complex<double>\" dataype";
616 return false;
617 }
618
DoBlasRotg(Stream * stream,DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)619 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
620 DeviceMemory<float> *b, DeviceMemory<float> *c,
621 DeviceMemory<float> *s) {
622 LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
623 << "for the \"float\" dataype";
624 return false;
625 }
626
DoBlasRotg(Stream * stream,DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)627 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
628 DeviceMemory<double> *b, DeviceMemory<double> *c,
629 DeviceMemory<double> *s) {
630 LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
631 << "for the \"double\" dataype";
632 return false;
633 }
634
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)635 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
636 DeviceMemory<std::complex<float>> *b,
637 DeviceMemory<float> *c,
638 DeviceMemory<std::complex<float>> *s) {
639 LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
640 << "for the \"complex<float>\" dataype";
641 return false;
642 }
643
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)644 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
645 DeviceMemory<std::complex<double>> *b,
646 DeviceMemory<double> *c,
647 DeviceMemory<std::complex<double>> *s) {
648 LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
649 << "for the \"complex<double>\" dataype";
650 return false;
651 }
652
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)653 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
654 DeviceMemory<float> *x, int incx,
655 DeviceMemory<float> *y, int incy,
656 const DeviceMemory<float> ¶m) {
657 LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
658 << "for the \"float\" dataype";
659 return false;
660 }
661
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)662 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
663 DeviceMemory<double> *x, int incx,
664 DeviceMemory<double> *y, int incy,
665 const DeviceMemory<double> ¶m) {
666 LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
667 << "for the \"double\" dataype";
668 return false;
669 }
670
DoBlasRotmg(Stream * stream,DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)671 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
672 DeviceMemory<float> *d2, DeviceMemory<float> *x1,
673 const DeviceMemory<float> &y1,
674 DeviceMemory<float> *param) {
675 LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
676 << "for the \"float\" dataype";
677 return false;
678 }
679
DoBlasRotmg(Stream * stream,DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)680 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
681 DeviceMemory<double> *d2, DeviceMemory<double> *x1,
682 const DeviceMemory<double> &y1,
683 DeviceMemory<double> *param) {
684 LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
685 << "for the \"double\" dataype";
686 return false;
687 }
688
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)689 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
690 DeviceMemory<float> *x, int incx) {
691 return DoBlasInternal(wrap::rocblas_sscal, stream,
692 true /* = pointer_mode_host */, elem_count, &alpha,
693 GpuMemoryMutable(x), incx);
694 }
695
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)696 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
697 DeviceMemory<double> *x, int incx) {
698 return DoBlasInternal(wrap::rocblas_dscal, stream,
699 true /* = pointer_mode_host */, elem_count, &alpha,
700 GpuMemoryMutable(x), incx);
701 }
702
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)703 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
704 DeviceMemory<std::complex<float>> *x, int incx) {
705 LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
706 << "for the \"complex<float>\" dataype";
707 return false;
708 }
709
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)710 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
711 DeviceMemory<std::complex<double>> *x, int incx) {
712 LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
713 << "for the \"complex<double>\" dataype";
714 return false;
715 }
716
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)717 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
718 std::complex<float> alpha,
719 DeviceMemory<std::complex<float>> *x, int incx) {
720 LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
721 << "for the \"complex<float>\" dataype";
722 return false;
723 }
724
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)725 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
726 std::complex<double> alpha,
727 DeviceMemory<std::complex<double>> *x, int incx) {
728 LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
729 << "for the \"complex<double>\" dataype";
730 return false;
731 }
732
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)733 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
734 DeviceMemory<float> *x, int incx,
735 DeviceMemory<float> *y, int incy) {
736 return DoBlasInternal(wrap::rocblas_sswap, stream,
737 true /* = pointer_mode_host */, elem_count,
738 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
739 }
740
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)741 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
742 DeviceMemory<double> *x, int incx,
743 DeviceMemory<double> *y, int incy) {
744 return DoBlasInternal(wrap::rocblas_dswap, stream,
745 true /* = pointer_mode_host */, elem_count,
746 GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
747 }
748
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)749 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
750 DeviceMemory<std::complex<float>> *x, int incx,
751 DeviceMemory<std::complex<float>> *y, int incy) {
752 LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
753 << "for the \"complex<float>\" dataype";
754 return false;
755 }
756
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)757 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
758 DeviceMemory<std::complex<double>> *x, int incx,
759 DeviceMemory<std::complex<double>> *y, int incy) {
760 LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
761 << "for the \"complex<double>\" dataype";
762 return false;
763 }
764
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)765 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
766 const DeviceMemory<float> &x, int incx,
767 DeviceMemory<int> *result) {
768 return DoBlasInternal(wrap::rocblas_isamax, stream,
769 false /* = pointer_mode_host */, elem_count,
770 GpuMemory(x), incx, GpuMemoryMutable(result));
771 }
772
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)773 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
774 const DeviceMemory<double> &x, int incx,
775 DeviceMemory<int> *result) {
776 return DoBlasInternal(wrap::rocblas_idamax, stream,
777 false /* = pointer_mode_host */, elem_count,
778 GpuMemory(x), incx, GpuMemoryMutable(result));
779 }
780
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)781 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
782 const DeviceMemory<std::complex<float>> &x, int incx,
783 DeviceMemory<int> *result) {
784 LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
785 << "for the \"complex<float>\" dataype";
786 return false;
787 }
788
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)789 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
790 const DeviceMemory<std::complex<double>> &x,
791 int incx, DeviceMemory<int> *result) {
792 LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
793 << "for the \"complex<double>\" dataype";
794 return false;
795 }
796
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)797 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
798 const DeviceMemory<float> &x, int incx,
799 DeviceMemory<int> *result) {
800 return DoBlasInternal(
801 wrap::rocblas_isamin, stream, false /* = pointer_mode_host */, elem_count,
802 GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
803 }
804
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)805 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
806 const DeviceMemory<double> &x, int incx,
807 DeviceMemory<int> *result) {
808 return DoBlasInternal(
809 wrap::rocblas_idamin, stream, false /* = pointer_mode_host */, elem_count,
810 GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
811 }
812
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)813 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
814 const DeviceMemory<std::complex<float>> &x, int incx,
815 DeviceMemory<int> *result) {
816 LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
817 << "for the \"complex<float>\" dataype";
818 return false;
819 }
820
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)821 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
822 const DeviceMemory<std::complex<double>> &x,
823 int incx, DeviceMemory<int> *result) {
824 LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
825 << "for the \"complex<double>\" dataype";
826 return false;
827 }
828
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)829 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
830 uint64 n, uint64 kl, uint64 ku, float alpha,
831 const DeviceMemory<float> &a, int lda,
832 const DeviceMemory<float> &x, int incx, float beta,
833 DeviceMemory<float> *y, int incy) {
834 LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
835 << "for the \"float\" dataype";
836 return false;
837 }
838
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)839 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
840 uint64 n, uint64 kl, uint64 ku, double alpha,
841 const DeviceMemory<double> &a, int lda,
842 const DeviceMemory<double> &x, int incx, double beta,
843 DeviceMemory<double> *y, int incy) {
844 LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
845 << "for the \"double\" dataype";
846 return false;
847 }
848
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)849 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
850 uint64 n, uint64 kl, uint64 ku,
851 std::complex<float> alpha,
852 const DeviceMemory<std::complex<float>> &a, int lda,
853 const DeviceMemory<std::complex<float>> &x, int incx,
854 std::complex<float> beta,
855 DeviceMemory<std::complex<float>> *y, int incy) {
856 LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
857 << "for the \"complex<float>\" dataype";
858 return false;
859 }
860
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)861 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
862 uint64 n, uint64 kl, uint64 ku,
863 std::complex<double> alpha,
864 const DeviceMemory<std::complex<double>> &a, int lda,
865 const DeviceMemory<std::complex<double>> &x, int incx,
866 std::complex<double> beta,
867 DeviceMemory<std::complex<double>> *y, int incy) {
868 LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
869 << "for the \"complex<double>\" dataype";
870 return false;
871 }
872
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)873 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
874 uint64 n, float alpha, const DeviceMemory<float> &a,
875 int lda, const DeviceMemory<float> &x, int incx,
876 float beta, DeviceMemory<float> *y, int incy) {
877 return DoBlasInternal(
878 wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */,
879 ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
880 incx, &beta, GpuMemoryMutable(y), incy);
881 }
882
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)883 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
884 uint64 n, double alpha, const DeviceMemory<double> &a,
885 int lda, const DeviceMemory<double> &x, int incx,
886 double beta, DeviceMemory<double> *y, int incy) {
887 return DoBlasInternal(
888 wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */,
889 ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
890 incx, &beta, GpuMemoryMutable(y), incy);
891 }
892
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)893 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
894 uint64 n, std::complex<float> alpha,
895 const DeviceMemory<std::complex<float>> &a, int lda,
896 const DeviceMemory<std::complex<float>> &x, int incx,
897 std::complex<float> beta,
898 DeviceMemory<std::complex<float>> *y, int incy) {
899 LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
900 << "for the \"complex<float>\" dataype";
901 return false;
902 }
903
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)904 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
905 uint64 n, std::complex<double> alpha,
906 const DeviceMemory<std::complex<double>> &a, int lda,
907 const DeviceMemory<std::complex<double>> &x, int incx,
908 std::complex<double> beta,
909 DeviceMemory<std::complex<double>> *y, int incy) {
910 LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
911 << "for the \"complex<double>\" dataype";
912 return false;
913 }
914
DoBlasGer(Stream * stream,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)915 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
916 const DeviceMemory<float> &x, int incx,
917 const DeviceMemory<float> &y, int incy,
918 DeviceMemory<float> *a, int lda) {
919 return DoBlasInternal(
920 wrap::rocblas_sger, stream, true /* = pointer_mode_host */, m, n, &alpha,
921 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
922 }
923
DoBlasGer(Stream * stream,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)924 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
925 const DeviceMemory<double> &x, int incx,
926 const DeviceMemory<double> &y, int incy,
927 DeviceMemory<double> *a, int lda) {
928 return DoBlasInternal(
929 wrap::rocblas_dger, stream, true /* = pointer_mode_host */, m, n, &alpha,
930 GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
931 }
932
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)933 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
934 std::complex<float> alpha,
935 const DeviceMemory<std::complex<float>> &x, int incx,
936 const DeviceMemory<std::complex<float>> &y, int incy,
937 DeviceMemory<std::complex<float>> *a, int lda) {
938 LOG(ERROR) << "rocBLAS does not currently support the GER operation "
939 << "for the \"complex<float>\" dataype";
940 return false;
941 }
942
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)943 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
944 std::complex<double> alpha,
945 const DeviceMemory<std::complex<double>> &x, int incx,
946 const DeviceMemory<std::complex<double>> &y, int incy,
947 DeviceMemory<std::complex<double>> *a, int lda) {
948 LOG(ERROR) << "rocBLAS does not currently support the GER operation "
949 << "for the \"complex<double>\" dataype";
950 return false;
951 }
952
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)953 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
954 std::complex<float> alpha,
955 const DeviceMemory<std::complex<float>> &x, int incx,
956 const DeviceMemory<std::complex<float>> &y, int incy,
957 DeviceMemory<std::complex<float>> *a, int lda) {
958 LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
959 << "for the \"complex<float>\" dataype";
960 return false;
961 }
962
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)963 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
964 std::complex<double> alpha,
965 const DeviceMemory<std::complex<double>> &x, int incx,
966 const DeviceMemory<std::complex<double>> &y, int incy,
967 DeviceMemory<std::complex<double>> *a, int lda) {
968 LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
969 << "for the \"complex<double>\" dataype";
970 return false;
971 }
972
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)973 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
974 uint64 k, std::complex<float> alpha,
975 const DeviceMemory<std::complex<float>> &a, int lda,
976 const DeviceMemory<std::complex<float>> &x, int incx,
977 std::complex<float> beta,
978 DeviceMemory<std::complex<float>> *y, int incy) {
979 LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
980 << "for the \"complex<float>\" dataype";
981 return false;
982 }
983
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)984 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
985 uint64 k, std::complex<double> alpha,
986 const DeviceMemory<std::complex<double>> &a, int lda,
987 const DeviceMemory<std::complex<double>> &x, int incx,
988 std::complex<double> beta,
989 DeviceMemory<std::complex<double>> *y, int incy) {
990 LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
991 << "for the \"complex<double>\" dataype";
992 return false;
993 }
994
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)995 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
996 std::complex<float> alpha,
997 const DeviceMemory<std::complex<float>> &a, int lda,
998 const DeviceMemory<std::complex<float>> &x, int incx,
999 std::complex<float> beta,
1000 DeviceMemory<std::complex<float>> *y, int incy) {
1001 LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
1002 << "for the \"complex<float>\" dataype";
1003 return false;
1004 }
1005
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1006 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1007 std::complex<double> alpha,
1008 const DeviceMemory<std::complex<double>> &a, int lda,
1009 const DeviceMemory<std::complex<double>> &x, int incx,
1010 std::complex<double> beta,
1011 DeviceMemory<std::complex<double>> *y, int incy) {
1012 LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
1013 << "for the \"complex<double>\" dataype";
1014 return false;
1015 }
1016
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)1017 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1018 float alpha,
1019 const DeviceMemory<std::complex<float>> &x, int incx,
1020 DeviceMemory<std::complex<float>> *a, int lda) {
1021 LOG(ERROR) << "rocBLAS does not currently support the HER operation "
1022 << "for the \"complex<float>\" dataype";
1023 return false;
1024 }
1025
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)1026 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1027 double alpha,
1028 const DeviceMemory<std::complex<double>> &x, int incx,
1029 DeviceMemory<std::complex<double>> *a, int lda) {
1030 LOG(ERROR) << "rocBLAS does not currently support the HER operation "
1031 << "for the \"complex<double>\" dataype";
1032 return false;
1033 }
1034
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)1035 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1036 std::complex<float> alpha,
1037 const DeviceMemory<std::complex<float>> &x, int incx,
1038 const DeviceMemory<std::complex<float>> &y, int incy,
1039 DeviceMemory<std::complex<float>> *a, int lda) {
1040 LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
1041 << "for the \"complex<float>\" dataype";
1042 return false;
1043 }
1044
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1045 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1046 std::complex<double> alpha,
1047 const DeviceMemory<std::complex<double>> &x, int incx,
1048 const DeviceMemory<std::complex<double>> &y, int incy,
1049 DeviceMemory<std::complex<double>> *a, int lda) {
1050 LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
1051 << "for the \"complex<double>\" dataype";
1052 return false;
1053 }
1054
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1055 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1056 std::complex<float> alpha,
1057 const DeviceMemory<std::complex<float>> &ap,
1058 const DeviceMemory<std::complex<float>> &x, int incx,
1059 std::complex<float> beta,
1060 DeviceMemory<std::complex<float>> *y, int incy) {
1061 LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
1062 << "for the \"complex<float>\" dataype";
1063 return false;
1064 }
1065
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1066 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1067 std::complex<double> alpha,
1068 const DeviceMemory<std::complex<double>> &ap,
1069 const DeviceMemory<std::complex<double>> &x, int incx,
1070 std::complex<double> beta,
1071 DeviceMemory<std::complex<double>> *y, int incy) {
1072 LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
1073 << "for the \"complex<double>\" dataype";
1074 return false;
1075 }
1076
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)1077 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1078 float alpha,
1079 const DeviceMemory<std::complex<float>> &x, int incx,
1080 DeviceMemory<std::complex<float>> *ap) {
1081 LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
1082 << "for the \"complex<float>\" dataype";
1083 return false;
1084 }
1085
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)1086 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1087 double alpha,
1088 const DeviceMemory<std::complex<double>> &x, int incx,
1089 DeviceMemory<std::complex<double>> *ap) {
1090 LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
1091 << "for the \"complex<double>\" dataype";
1092 return false;
1093 }
1094
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)1095 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1096 std::complex<float> alpha,
1097 const DeviceMemory<std::complex<float>> &x, int incx,
1098 const DeviceMemory<std::complex<float>> &y, int incy,
1099 DeviceMemory<std::complex<float>> *ap) {
1100 LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
1101 << "for the \"complex<float>\" dataype";
1102 return false;
1103 }
1104
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)1105 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1106 std::complex<double> alpha,
1107 const DeviceMemory<std::complex<double>> &x, int incx,
1108 const DeviceMemory<std::complex<double>> &y, int incy,
1109 DeviceMemory<std::complex<double>> *ap) {
1110 LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
1111 << "for the \"complex<double>\" dataype";
1112 return false;
1113 }
1114
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1115 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1116 uint64 k, float alpha, const DeviceMemory<float> &a,
1117 int lda, const DeviceMemory<float> &x, int incx,
1118 float beta, DeviceMemory<float> *y, int incy) {
1119 LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
1120 << "for the \"complex<float>\" dataype";
1121
1122 return false;
1123 }
1124
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1125 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1126 uint64 k, double alpha, const DeviceMemory<double> &a,
1127 int lda, const DeviceMemory<double> &x, int incx,
1128 double beta, DeviceMemory<double> *y, int incy) {
1129 LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
1130 << "for the \"complex<double>\" dataype";
1131 return false;
1132 }
1133
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1134 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1135 float alpha, const DeviceMemory<float> &ap,
1136 const DeviceMemory<float> &x, int incx, float beta,
1137 DeviceMemory<float> *y, int incy) {
1138 LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
1139 << "for the \"float\" dataype";
1140 return false;
1141 }
1142
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1143 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1144 double alpha, const DeviceMemory<double> &ap,
1145 const DeviceMemory<double> &x, int incx, double beta,
1146 DeviceMemory<double> *y, int incy) {
1147 LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
1148 << "for the \"double\" dataype";
1149 return false;
1150 }
1151
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)1152 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1153 float alpha, const DeviceMemory<float> &x, int incx,
1154 DeviceMemory<float> *ap) {
1155 LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
1156 << "for the \"float\" dataype";
1157 return false;
1158 }
1159
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)1160 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1161 double alpha, const DeviceMemory<double> &x, int incx,
1162 DeviceMemory<double> *ap) {
1163 LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
1164 << "for the \"double\" dataype";
1165 return false;
1166 }
1167
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)1168 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1169 float alpha, const DeviceMemory<float> &x, int incx,
1170 const DeviceMemory<float> &y, int incy,
1171 DeviceMemory<float> *ap) {
1172 LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
1173 << "for the \"float\" dataype";
1174 return false;
1175 }
1176
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)1177 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1178 double alpha, const DeviceMemory<double> &x, int incx,
1179 const DeviceMemory<double> &y, int incy,
1180 DeviceMemory<double> *ap) {
1181 LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
1182 << "for the \"double\" dataype";
1183 return false;
1184 }
1185
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1186 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1187 float alpha, const DeviceMemory<float> &a, int lda,
1188 const DeviceMemory<float> &x, int incx, float beta,
1189 DeviceMemory<float> *y, int incy) {
1190 LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
1191 << "for the \"float\" dataype";
1192 return false;
1193 }
1194
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1195 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1196 double alpha, const DeviceMemory<double> &a, int lda,
1197 const DeviceMemory<double> &x, int incx, double beta,
1198 DeviceMemory<double> *y, int incy) {
1199 LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
1200 << "for the \"double\" dataype";
1201 return false;
1202 }
1203
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)1204 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1205 float alpha, const DeviceMemory<float> &x, int incx,
1206 DeviceMemory<float> *a, int lda) {
1207 return DoBlasInternal(wrap::rocblas_ssyr, stream,
1208 true /* = pointer_mode_host */,
1209 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1210 GpuMemoryMutable(a), lda);
1211 }
1212
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)1213 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1214 double alpha, const DeviceMemory<double> &x, int incx,
1215 DeviceMemory<double> *a, int lda) {
1216 return DoBlasInternal(wrap::rocblas_dsyr, stream,
1217 true /* = pointer_mode_host */,
1218 ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1219 GpuMemoryMutable(a), lda);
1220 }
1221
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)1222 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1223 float alpha, const DeviceMemory<float> &x, int incx,
1224 const DeviceMemory<float> &y, int incy,
1225 DeviceMemory<float> *a, int lda) {
1226 LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
1227 << "for the \"float\" dataype";
1228 return false;
1229 }
1230
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)1231 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1232 double alpha, const DeviceMemory<double> &x, int incx,
1233 const DeviceMemory<double> &y, int incy,
1234 DeviceMemory<double> *a, int lda) {
1235 LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
1236 << "for the \"double\" dataype";
1237 return false;
1238 }
1239
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1240 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1241 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1242 uint64 k, const DeviceMemory<float> &a, int lda,
1243 DeviceMemory<float> *x, int incx) {
1244 LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1245 << "for the \"float\" dataype";
1246 return false;
1247 }
1248
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1249 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1250 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1251 uint64 k, const DeviceMemory<double> &a, int lda,
1252 DeviceMemory<double> *x, int incx) {
1253 LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1254 << "for the \"double\" dataype";
1255 return false;
1256 }
1257
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1258 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1259 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1260 uint64 k, const DeviceMemory<std::complex<float>> &a,
1261 int lda, DeviceMemory<std::complex<float>> *x,
1262 int incx) {
1263 LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1264 << "for the \"complex<float>\" dataype";
1265 return false;
1266 }
1267
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1268 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1269 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1270 uint64 k, const DeviceMemory<std::complex<double>> &a,
1271 int lda, DeviceMemory<std::complex<double>> *x,
1272 int incx) {
1273 LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1274 << "for the \"complex<double>\" dataype";
1275 return false;
1276 }
1277
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1278 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1279 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1280 uint64 k, const DeviceMemory<float> &a, int lda,
1281 DeviceMemory<float> *x, int incx) {
1282 LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1283 << "for the \"float\" dataype";
1284 return false;
1285 }
1286
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1287 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1288 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1289 uint64 k, const DeviceMemory<double> &a, int lda,
1290 DeviceMemory<double> *x, int incx) {
1291 LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1292 << "for the \"double\" dataype";
1293 return false;
1294 }
1295
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1296 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1297 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1298 uint64 k, const DeviceMemory<std::complex<float>> &a,
1299 int lda, DeviceMemory<std::complex<float>> *x,
1300 int incx) {
1301 LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1302 << "for the \"complex<float>\" dataype";
1303 return false;
1304 }
1305
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1306 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1307 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1308 uint64 k, const DeviceMemory<std::complex<double>> &a,
1309 int lda, DeviceMemory<std::complex<double>> *x,
1310 int incx) {
1311 LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1312 << "for the \"complex<double>\" dataype";
1313 return false;
1314 }
1315
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1316 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1317 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1318 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1319 int incx) {
1320 LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1321 << "for the \"float\" dataype";
1322 return false;
1323 }
1324
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1325 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1326 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1327 const DeviceMemory<double> &ap,
1328 DeviceMemory<double> *x, int incx) {
1329 LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1330 << "for the \"double\" dataype";
1331 return false;
1332 }
1333
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1334 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1335 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1336 const DeviceMemory<std::complex<float>> &ap,
1337 DeviceMemory<std::complex<float>> *x, int incx) {
1338 LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1339 << "for the \"complex<float>\" dataype";
1340 return false;
1341 }
1342
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1343 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1344 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1345 const DeviceMemory<std::complex<double>> &ap,
1346 DeviceMemory<std::complex<double>> *x, int incx) {
1347 LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1348 << "for the \"complex<double>\" dataype";
1349 return false;
1350 }
1351
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1352 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1353 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1354 const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1355 int incx) {
1356 LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1357 << "for the \"float\" dataype";
1358 return false;
1359 }
1360
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1361 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1362 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1363 const DeviceMemory<double> &ap,
1364 DeviceMemory<double> *x, int incx) {
1365 LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1366 << "for the \"double\" dataype";
1367 return false;
1368 }
1369
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1370 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1371 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1372 const DeviceMemory<std::complex<float>> &ap,
1373 DeviceMemory<std::complex<float>> *x, int incx) {
1374 LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1375 << "for the \"complex<float>\" dataype";
1376 return false;
1377 }
1378
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1379 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1380 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1381 const DeviceMemory<std::complex<double>> &ap,
1382 DeviceMemory<std::complex<double>> *x, int incx) {
1383 LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1384 << "for the \"complex<double>\" dataype";
1385 return false;
1386 }
1387
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1388 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1389 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1390 const DeviceMemory<float> &a, int lda,
1391 DeviceMemory<float> *x, int incx) {
1392 LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1393 << "for the \"float\" dataype";
1394 return false;
1395 }
1396
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1397 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1398 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1399 const DeviceMemory<double> &a, int lda,
1400 DeviceMemory<double> *x, int incx) {
1401 LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1402 << "for the \"double\" dataype";
1403 return false;
1404 }
1405
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1406 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1407 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1408 const DeviceMemory<std::complex<float>> &a, int lda,
1409 DeviceMemory<std::complex<float>> *x, int incx) {
1410 LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1411 << "for the \"complex<float>\" dataype";
1412 return false;
1413 }
1414
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1415 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1416 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1417 const DeviceMemory<std::complex<double>> &a, int lda,
1418 DeviceMemory<std::complex<double>> *x, int incx) {
1419 LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1420 << "for the \"complex<double>\" dataype";
1421 return false;
1422 }
1423
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1424 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1425 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1426 const DeviceMemory<float> &a, int lda,
1427 DeviceMemory<float> *x, int incx) {
1428 LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1429 << "for the \"float\" dataype";
1430 return false;
1431 }
1432
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1433 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1434 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1435 const DeviceMemory<double> &a, int lda,
1436 DeviceMemory<double> *x, int incx) {
1437 LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1438 << "for the \"double\" dataype";
1439 return false;
1440 }
1441
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1442 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1443 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1444 const DeviceMemory<std::complex<float>> &a, int lda,
1445 DeviceMemory<std::complex<float>> *x, int incx) {
1446 LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1447 << "for the \"complex<float>\" dataype";
1448 return false;
1449 }
1450
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1451 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1452 blas::Transpose trans, blas::Diagonal diag, uint64 n,
1453 const DeviceMemory<std::complex<double>> &a, int lda,
1454 DeviceMemory<std::complex<double>> *x, int incx) {
1455 LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1456 << "for the \"complex<double>\" dataype";
1457 return false;
1458 }
1459
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)1460 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1461 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1462 float alpha, const DeviceMemory<Eigen::half> &a,
1463 int lda, const DeviceMemory<Eigen::half> &b, int ldb,
1464 float beta, DeviceMemory<Eigen::half> *c, int ldc) {
1465 VLOG(1) << port::Printf(
1466 "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
1467 "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1468 "c=%p ldc=%d",
1469 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1470 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1471 if (transa == blas::Transpose::kNoTranspose) {
1472 if (lda < static_cast<int64>(m)) {
1473 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1474 "precondition violation";
1475 }
1476 } else {
1477 if (lda < static_cast<int64>(k)) {
1478 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1479 << ") (transpose case); precondition violation";
1480 }
1481 }
1482 if (transb == blas::Transpose::kNoTranspose) {
1483 if (ldb < static_cast<int64>(k)) {
1484 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1485 << ") (no transpose case); precondition violation";
1486 }
1487 } else {
1488 if (ldb < static_cast<int64>(n)) {
1489 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1490 "precondition violation";
1491 }
1492 }
1493 const Eigen::half alpha_half(alpha);
1494 const Eigen::half beta_half(beta);
1495 return DoBlasInternal(
1496 wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
1497 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1498 reinterpret_cast<const rocblas_half *>(&alpha_half),
1499 reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
1500 reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
1501 reinterpret_cast<const rocblas_half *>(&beta_half),
1502 reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
1503 }
1504
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)1505 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1506 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1507 float alpha, const DeviceMemory<float> &a, int lda,
1508 const DeviceMemory<float> &b, int ldb, float beta,
1509 DeviceMemory<float> *c, int ldc) {
1510 VLOG(1) << port::Printf(
1511 "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
1512 "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1513 "c=%p ldc=%d",
1514 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1515 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1516 if (transa == blas::Transpose::kNoTranspose) {
1517 if (lda < static_cast<int64>(m)) {
1518 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1519 "precondition violation";
1520 }
1521 } else {
1522 if (lda < static_cast<int64>(k)) {
1523 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1524 << ") (transpose case); precondition violation";
1525 }
1526 }
1527 if (transb == blas::Transpose::kNoTranspose) {
1528 if (ldb < static_cast<int64>(k)) {
1529 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1530 << ") (no transpose case); precondition violation";
1531 }
1532 } else {
1533 if (ldb < static_cast<int64>(n)) {
1534 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1535 "precondition violation";
1536 }
1537 }
1538 return DoBlasInternal(
1539 wrap::rocblas_sgemm, stream, true /* = pointer_mode_host */,
1540 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1541 GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1542 }
1543
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)1544 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1545 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1546 double alpha, const DeviceMemory<double> &a, int lda,
1547 const DeviceMemory<double> &b, int ldb, double beta,
1548 DeviceMemory<double> *c, int ldc) {
1549 return DoBlasInternal(
1550 wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
1551 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1552 GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1553 }
1554
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)1555 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1556 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1557 std::complex<float> alpha,
1558 const DeviceMemory<std::complex<float>> &a, int lda,
1559 const DeviceMemory<std::complex<float>> &b, int ldb,
1560 std::complex<float> beta,
1561 DeviceMemory<std::complex<float>> *c, int ldc) {
1562 LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
1563 << "for the \"complex<float>\" dataype";
1564 return false;
1565 }
1566
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)1567 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1568 blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1569 std::complex<double> alpha,
1570 const DeviceMemory<std::complex<double>> &a, int lda,
1571 const DeviceMemory<std::complex<double>> &b, int ldb,
1572 std::complex<double> beta,
1573 DeviceMemory<std::complex<double>> *c, int ldc) {
1574 LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
1575 << "for the \"complex<double>\" dataype";
1576 return false;
1577 }
1578
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)1579 bool ROCMBlas::DoBlasGemvWithProfiling(
1580 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1581 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1582 int incx, float beta, DeviceMemory<float> *y, int incy,
1583 blas::ProfileResult *output_profile_result) {
1584 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1585 incx, beta, y, incy,
1586 output_profile_result);
1587 }
1588
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)1589 bool ROCMBlas::DoBlasGemvWithProfiling(
1590 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1591 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1592 int incx, double beta, DeviceMemory<double> *y, int incy,
1593 blas::ProfileResult *output_profile_result) {
1594 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1595 incx, beta, y, incy,
1596 output_profile_result);
1597 }
1598
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)1599 bool ROCMBlas::DoBlasGemvWithProfiling(
1600 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1601 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1602 int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1603 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1604 blas::ProfileResult *output_profile_result) {
1605 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1606 incx, beta, y, incy,
1607 output_profile_result);
1608 }
1609
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)1610 bool ROCMBlas::DoBlasGemvWithProfiling(
1611 Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1612 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1613 int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1614 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1615 blas::ProfileResult *output_profile_result) {
1616 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1617 incx, beta, y, incy,
1618 output_profile_result);
1619 }
1620
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)1621 bool ROCMBlas::DoBlasGemmWithProfiling(
1622 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1623 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1624 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1625 DeviceMemory<Eigen::half> *c, int ldc,
1626 blas::ProfileResult *output_profile_result) {
1627 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1628 lda, b, ldb, beta, c, ldc,
1629 output_profile_result);
1630 }
1631
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)1632 bool ROCMBlas::DoBlasGemmWithProfiling(
1633 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1634 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1635 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1636 int ldc, blas::ProfileResult *output_profile_result) {
1637 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1638 lda, b, ldb, beta, c, ldc,
1639 output_profile_result);
1640 }
1641
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)1642 bool ROCMBlas::DoBlasGemmWithProfiling(
1643 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1644 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1645 const DeviceMemory<double> &b, int ldb, double beta,
1646 DeviceMemory<double> *c, int ldc,
1647 blas::ProfileResult *output_profile_result) {
1648 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1649 lda, b, ldb, beta, c, ldc,
1650 output_profile_result);
1651 }
1652
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)1653 bool ROCMBlas::DoBlasGemmWithProfiling(
1654 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1655 uint64 n, uint64 k, std::complex<float> alpha,
1656 const DeviceMemory<std::complex<float>> &a, int lda,
1657 const DeviceMemory<std::complex<float>> &b, int ldb,
1658 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1659 blas::ProfileResult *output_profile_result) {
1660 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1661 lda, b, ldb, beta, c, ldc,
1662 output_profile_result);
1663 }
1664
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)1665 bool ROCMBlas::DoBlasGemmWithProfiling(
1666 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1667 uint64 n, uint64 k, std::complex<double> alpha,
1668 const DeviceMemory<std::complex<double>> &a, int lda,
1669 const DeviceMemory<std::complex<double>> &b, int ldb,
1670 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1671 blas::ProfileResult *output_profile_result) {
1672 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1673 lda, b, ldb, beta, c, ldc,
1674 output_profile_result);
1675 }
1676
1677 template <typename T>
DoBlasGemvWithProfilingImpl(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,const T & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & x,int incx,const T & beta,DeviceMemory<T> * y,int incy,blas::ProfileResult * output_profile_result)1678 bool ROCMBlas::DoBlasGemvWithProfilingImpl(
1679 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1680 const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1681 const T &beta, DeviceMemory<T> *y, int incy,
1682 blas::ProfileResult *output_profile_result) {
1683 // ROCM TODO: properly implement the interface
1684 return false;
1685 }
1686
1687 template <typename T, typename ParamType>
DoBlasGemmWithProfilingImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const ParamType & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & b,int ldb,const ParamType & beta,DeviceMemory<T> * c,int ldc,blas::ProfileResult * output_profile_result)1688 bool ROCMBlas::DoBlasGemmWithProfilingImpl(
1689 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1690 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
1691 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
1692 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
1693 // ROCM TODO: properly implement the interface
1694 return false;
1695 }
1696
1697 template <typename InT, typename OutT, typename CompT>
DoBlasGemmWithAlgorithmImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const CompT & alpha,const DeviceMemory<InT> & a,int lda,const DeviceMemory<InT> & b,int ldb,const CompT & beta,DeviceMemory<OutT> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1698 bool ROCMBlas::DoBlasGemmWithAlgorithmImpl(
1699 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1700 uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
1701 const DeviceMemory<InT> &b, int ldb, const CompT &beta,
1702 DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
1703 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1704 // ROCM TODO: properly implement the interface
1705 return false;
1706 }
1707
GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> * out_algorithms)1708 bool ROCMBlas::GetBlasGemmAlgorithms(
1709 std::vector<blas::AlgorithmType> *out_algorithms) {
1710 // ROCM TODO: properly implement the interface
1711 return true;
1712 }
1713
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int32> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1714 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1715 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1716 uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
1717 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
1718 const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
1719 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1720 blas::ProfileResult *output_profile_result) {
1721 LOG(ERROR)
1722 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1723 << "for the \"int8\" dataype";
1724 return false;
1725 }
1726
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1727 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1728 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1729 uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1730 const DeviceMemory<Eigen::half> &a, int lda,
1731 const DeviceMemory<Eigen::half> &b, int ldb,
1732 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1733 int ldc, blas::ComputationType computation_type,
1734 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1735 LOG(ERROR)
1736 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1737 << "for the \"half\" dataype";
1738 return false;
1739 }
1740
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1741 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1742 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1743 uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
1744 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1745 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1746 int ldc, blas::ComputationType computation_type,
1747 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1748 LOG(ERROR)
1749 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1750 << "for the \"float\" dataype";
1751 return false;
1752 }
1753
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1754 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1755 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1756 uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
1757 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1758 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1759 int ldc, blas::ComputationType computation_type,
1760 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1761 LOG(ERROR)
1762 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1763 << "for the \"double\" dataype";
1764 return false;
1765 }
1766
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1767 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1768 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1769 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1770 const DeviceMemory<std::complex<float>> &a, int lda,
1771 const DeviceMemory<std::complex<float>> &b, int ldb,
1772 const HostOrDeviceScalar<std::complex<float>> &beta,
1773 DeviceMemory<std::complex<float>> *c, int ldc,
1774 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1775 blas::ProfileResult *output_profile_result) {
1776 LOG(ERROR)
1777 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1778 << "for the \"complex<float>\" dataype";
1779 return false;
1780 }
1781
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1782 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1783 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1784 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1785 const DeviceMemory<std::complex<double>> &a, int lda,
1786 const DeviceMemory<std::complex<double>> &b, int ldb,
1787 const HostOrDeviceScalar<std::complex<double>> &beta,
1788 DeviceMemory<std::complex<double>> *c, int ldc,
1789 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1790 blas::ProfileResult *output_profile_result) {
1791 LOG(ERROR)
1792 << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1793 << "for the \"complex<double>\" dataype";
1794 return false;
1795 }
1796
1797 template <typename T>
1798 struct EigenHalfToRocBlasHalf {
1799 using type = T;
1800 };
1801
1802 template <>
1803 struct EigenHalfToRocBlasHalf<Eigen::half> {
1804 using type = rocblas_half;
1805 };
1806
1807 template <typename T, typename FuncT>
DoBlasGemmBatchedInternal(FuncT rocblas_func,Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,T alpha,const port::ArraySlice<DeviceMemory<T> * > & a_ptrs_to_wrappers,int lda,const port::ArraySlice<DeviceMemory<T> * > & b_ptrs_to_wrappers,int ldb,T beta,const port::ArraySlice<DeviceMemory<T> * > & c_ptrs_to_wrappers,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1808 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
1809 FuncT rocblas_func, Stream *stream, blas::Transpose transa,
1810 blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
1811 const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
1812 const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
1813 T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
1814 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1815 // MAPPED_T will be same as T for all types except Eigen::Half
1816 // for T = Eigen::half, MAPPED_T = rocblas_half
1817 using MAPPED_T = typename EigenHalfToRocBlasHalf<T>::type;
1818
1819 // Alocate local vectors to hold device pointers to matrices
1820 std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
1821 for (int i = 0; i < batch_count; ++i) {
1822 // static_cast does work when converting Eigen::half* to rocblas_half*,
1823 // hence the use od reinterpret_cast
1824 a_raw_ptrs.push_back(
1825 reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
1826 b_raw_ptrs.push_back(
1827 reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
1828 c_raw_ptrs.push_back(
1829 reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
1830 }
1831
1832 // batch_count <= 1 is base case, no definable matrix stride, set it same as
1833 // ld*
1834 long long bsa = lda;
1835 long long bsb = ldb;
1836 long long bsc = ldc;
1837 bool bsa_is_constant = true;
1838 bool bsb_is_constant = true;
1839 bool bsc_is_constant = true;
1840
1841 if (batch_count > 1) {
1842 // Remember first stride; if any other stride is different that this one,
1843 // KABLAM
1844 bsa = a_raw_ptrs[1] - a_raw_ptrs[0];
1845 bsb = b_raw_ptrs[1] - b_raw_ptrs[0];
1846 bsc = c_raw_ptrs[1] - c_raw_ptrs[0];
1847
1848 // Loop to verify that batched strides are constant
1849 // All the test cases from batch_matmul_op_test.py seem to satisfy this
1850 // requirement of a constant stride. If this can be proven globally, then
1851 // this loop check can be safely removed
1852 for (int i = 1; i < batch_count - 1; ++i) {
1853 long long iterative_bsa = a_raw_ptrs[i + 1] - a_raw_ptrs[i];
1854 if (iterative_bsa != bsa) {
1855 bsa_is_constant = false;
1856 break;
1857 }
1858
1859 long long iterative_bsb = b_raw_ptrs[i + 1] - b_raw_ptrs[i];
1860 if (iterative_bsb != bsb) {
1861 bsb_is_constant = false;
1862 break;
1863 }
1864
1865 long long iterative_bsc = c_raw_ptrs[i + 1] - c_raw_ptrs[i];
1866 if (iterative_bsc != bsc) {
1867 bsc_is_constant = false;
1868 break;
1869 }
1870 }
1871 }
1872
1873 assert(!(ldc < m || bsc < ldc * n));
1874
1875 if (ROCMBlasTranspose(transa) == rocblas_operation_none)
1876 assert(!(lda < m || bsa < lda * k));
1877 else
1878 assert(!(lda < k || bsa < lda * m));
1879
1880 if (ROCMBlasTranspose(transb) == rocblas_operation_none)
1881 assert(!(ldb < k || bsb < ldb * n));
1882 else
1883 assert(!(ldb < n || bsc < ldc * k));
1884
1885 MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
1886 MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
1887
1888 if (bsa_is_constant && bsb_is_constant && bsc_is_constant) {
1889 bool ok = DoBlasInternal(
1890 rocblas_func, stream, true /* = pointer_mode_host */,
1891 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1892 GpuComplex(alpha_ptr), a_raw_ptrs[0], lda, bsa, b_raw_ptrs[0], ldb, bsb,
1893 GpuComplex(beta_ptr), c_raw_ptrs[0], ldc, bsc, batch_count);
1894
1895 if (ok) {
1896 return port::Status::OK();
1897 }
1898 }
1899
1900 return port::Status(port::error::INTERNAL,
1901 "failed BLAS call, see log for details");
1902 }
1903
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1904 bool ROCMBlas::DoBlasGemmBatched(
1905 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1906 uint64 n, uint64 k, float alpha,
1907 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1908 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
1909 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
1910 int batch_count, ScratchAllocator *scratch_allocator) {
1911 const Eigen::half alpha_half(alpha);
1912 const Eigen::half beta_half(beta);
1913
1914 port::Status status = DoBlasGemmBatchedInternal(
1915 wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
1916 alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
1917 scratch_allocator);
1918 if (!status.ok()) {
1919 LOG(ERROR) << status;
1920 }
1921
1922 return status.ok();
1923 }
1924
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<float> * > & b_array,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1925 bool ROCMBlas::DoBlasGemmBatched(
1926 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1927 uint64 n, uint64 k, float alpha,
1928 const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
1929 const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
1930 const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
1931 int batch_count, ScratchAllocator *scratch_allocator) {
1932 port::Status status = DoBlasGemmBatchedInternal(
1933 wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
1934 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
1935 scratch_allocator);
1936 if (!status.ok()) {
1937 LOG(ERROR) << status;
1938 }
1939 return status.ok();
1940 }
1941
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<double> * > & b_array,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1942 bool ROCMBlas::DoBlasGemmBatched(
1943 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1944 uint64 n, uint64 k, double alpha,
1945 const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
1946 const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
1947 double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
1948 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1949 port::Status status = DoBlasGemmBatchedInternal(
1950 wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
1951 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
1952 scratch_allocator);
1953 if (!status.ok()) {
1954 LOG(ERROR) << status;
1955 }
1956 return status.ok();
1957 }
1958
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b_array,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1959 bool ROCMBlas::DoBlasGemmBatched(
1960 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1961 uint64 n, uint64 k, std::complex<float> alpha,
1962 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
1963 int lda,
1964 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
1965 int ldb, std::complex<float> beta,
1966 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
1967 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1968 LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
1969 << "for the \"complex<float>\" dataype";
1970 return false;
1971 }
1972
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b_array,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1973 bool ROCMBlas::DoBlasGemmBatched(
1974 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1975 uint64 n, uint64 k, std::complex<double> alpha,
1976 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
1977 int lda,
1978 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
1979 int ldb, std::complex<double> beta,
1980 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
1981 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1982 LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
1983 << "for the \"complex<double>\" dataype";
1984 return false;
1985 }
1986
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)1987 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
1988 blas::UpperLower uplo, uint64 m, uint64 n,
1989 std::complex<float> alpha,
1990 const DeviceMemory<std::complex<float>> &a, int lda,
1991 const DeviceMemory<std::complex<float>> &b, int ldb,
1992 std::complex<float> beta,
1993 DeviceMemory<std::complex<float>> *c, int ldc) {
1994 LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
1995 << "for the \"complex<float>\" dataype";
1996 return false;
1997 }
1998
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)1999 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
2000 blas::UpperLower uplo, uint64 m, uint64 n,
2001 std::complex<double> alpha,
2002 const DeviceMemory<std::complex<double>> &a, int lda,
2003 const DeviceMemory<std::complex<double>> &b, int ldb,
2004 std::complex<double> beta,
2005 DeviceMemory<std::complex<double>> *c, int ldc) {
2006 LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
2007 << "for the \"complex<double>\" dataype";
2008 return false;
2009 }
2010
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2011 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2012 blas::Transpose trans, uint64 n, uint64 k,
2013 float alpha,
2014 const DeviceMemory<std::complex<float>> &a, int lda,
2015 float beta, DeviceMemory<std::complex<float>> *c,
2016 int ldc) {
2017 LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
2018 << "for the \"complex<float>\" dataype";
2019 return false;
2020 }
2021
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2022 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2023 blas::Transpose trans, uint64 n, uint64 k,
2024 double alpha,
2025 const DeviceMemory<std::complex<double>> &a, int lda,
2026 double beta, DeviceMemory<std::complex<double>> *c,
2027 int ldc) {
2028 LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
2029 << "for the \"complex<double>\" dataype";
2030 return false;
2031 }
2032
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2033 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2034 blas::Transpose trans, uint64 n, uint64 k,
2035 std::complex<float> alpha,
2036 const DeviceMemory<std::complex<float>> &a, int lda,
2037 const DeviceMemory<std::complex<float>> &b, int ldb,
2038 float beta, DeviceMemory<std::complex<float>> *c,
2039 int ldc) {
2040 LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
2041 << "for the \"complex<float>\" dataype";
2042 return false;
2043 }
2044
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2045 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2046 blas::Transpose trans, uint64 n, uint64 k,
2047 std::complex<double> alpha,
2048 const DeviceMemory<std::complex<double>> &a, int lda,
2049 const DeviceMemory<std::complex<double>> &b, int ldb,
2050 double beta, DeviceMemory<std::complex<double>> *c,
2051 int ldc) {
2052 LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
2053 << "for the \"complex<double>\" dataype";
2054 return false;
2055 }
2056
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2057 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2058 blas::UpperLower uplo, uint64 m, uint64 n,
2059 float alpha, const DeviceMemory<float> &a, int lda,
2060 const DeviceMemory<float> &b, int ldb, float beta,
2061 DeviceMemory<float> *c, int ldc) {
2062 LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2063 << "for the \"float\" dataype";
2064 return false;
2065 }
2066
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2067 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2068 blas::UpperLower uplo, uint64 m, uint64 n,
2069 double alpha, const DeviceMemory<double> &a, int lda,
2070 const DeviceMemory<double> &b, int ldb, double beta,
2071 DeviceMemory<double> *c, int ldc) {
2072 LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2073 << "for the \"double\" dataype";
2074 return false;
2075 }
2076
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2077 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2078 blas::UpperLower uplo, uint64 m, uint64 n,
2079 std::complex<float> alpha,
2080 const DeviceMemory<std::complex<float>> &a, int lda,
2081 const DeviceMemory<std::complex<float>> &b, int ldb,
2082 std::complex<float> beta,
2083 DeviceMemory<std::complex<float>> *c, int ldc) {
2084 LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2085 << "for the \"complex<float>\" dataype";
2086 return false;
2087 }
2088
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2089 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2090 blas::UpperLower uplo, uint64 m, uint64 n,
2091 std::complex<double> alpha,
2092 const DeviceMemory<std::complex<double>> &a, int lda,
2093 const DeviceMemory<std::complex<double>> &b, int ldb,
2094 std::complex<double> beta,
2095 DeviceMemory<std::complex<double>> *c, int ldc) {
2096 LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2097 << "for the \"complex<double>\" dataype";
2098 return false;
2099 }
2100
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)2101 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2102 blas::Transpose trans, uint64 n, uint64 k,
2103 float alpha, const DeviceMemory<float> &a, int lda,
2104 float beta, DeviceMemory<float> *c, int ldc) {
2105 LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2106 << "for the \"float\" dataype";
2107 return false;
2108 }
2109
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)2110 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2111 blas::Transpose trans, uint64 n, uint64 k,
2112 double alpha, const DeviceMemory<double> &a, int lda,
2113 double beta, DeviceMemory<double> *c, int ldc) {
2114 LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2115 << "for the \"double\" dataype";
2116 return false;
2117 }
2118
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2119 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2120 blas::Transpose trans, uint64 n, uint64 k,
2121 std::complex<float> alpha,
2122 const DeviceMemory<std::complex<float>> &a, int lda,
2123 std::complex<float> beta,
2124 DeviceMemory<std::complex<float>> *c, int ldc) {
2125 LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2126 << "for the \"complex<float>\" dataype";
2127 return false;
2128 }
2129
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2130 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2131 blas::Transpose trans, uint64 n, uint64 k,
2132 std::complex<double> alpha,
2133 const DeviceMemory<std::complex<double>> &a, int lda,
2134 std::complex<double> beta,
2135 DeviceMemory<std::complex<double>> *c, int ldc) {
2136 LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2137 << "for the \"complex<double>\" dataype";
2138 return false;
2139 }
2140
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2141 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2142 blas::Transpose trans, uint64 n, uint64 k,
2143 float alpha, const DeviceMemory<float> &a, int lda,
2144 const DeviceMemory<float> &b, int ldb, float beta,
2145 DeviceMemory<float> *c, int ldc) {
2146 LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2147 << "for the \"float\" dataype";
2148 return false;
2149 }
2150
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2151 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2152 blas::Transpose trans, uint64 n, uint64 k,
2153 double alpha, const DeviceMemory<double> &a, int lda,
2154 const DeviceMemory<double> &b, int ldb, double beta,
2155 DeviceMemory<double> *c, int ldc) {
2156 LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2157 << "for the \"double\" dataype";
2158 return false;
2159 }
2160
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2161 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2162 blas::Transpose trans, uint64 n, uint64 k,
2163 std::complex<float> alpha,
2164 const DeviceMemory<std::complex<float>> &a, int lda,
2165 const DeviceMemory<std::complex<float>> &b, int ldb,
2166 std::complex<float> beta,
2167 DeviceMemory<std::complex<float>> *c, int ldc) {
2168 LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2169 << "for the \"complex<float>\" dataype";
2170 return false;
2171 }
2172
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2173 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2174 blas::Transpose trans, uint64 n, uint64 k,
2175 std::complex<double> alpha,
2176 const DeviceMemory<std::complex<double>> &a, int lda,
2177 const DeviceMemory<std::complex<double>> &b, int ldb,
2178 std::complex<double> beta,
2179 DeviceMemory<std::complex<double>> *c, int ldc) {
2180 LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2181 << "for the \"complex<double>\" dataype";
2182 return false;
2183 }
2184
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2185 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2186 blas::UpperLower uplo, blas::Transpose transa,
2187 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2188 const DeviceMemory<float> &a, int lda,
2189 DeviceMemory<float> *b, int ldb) {
2190 LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2191 << "for the \"float\" dataype";
2192 return false;
2193 }
2194
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2195 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2196 blas::UpperLower uplo, blas::Transpose transa,
2197 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2198 const DeviceMemory<double> &a, int lda,
2199 DeviceMemory<double> *b, int ldb) {
2200 LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2201 << "for the \"double\" dataype";
2202 return false;
2203 }
2204
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2205 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2206 blas::UpperLower uplo, blas::Transpose transa,
2207 blas::Diagonal diag, uint64 m, uint64 n,
2208 std::complex<float> alpha,
2209 const DeviceMemory<std::complex<float>> &a, int lda,
2210 DeviceMemory<std::complex<float>> *b, int ldb) {
2211 LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2212 << "for the \"complex<float>\" dataype";
2213 return false;
2214 }
2215
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2216 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2217 blas::UpperLower uplo, blas::Transpose transa,
2218 blas::Diagonal diag, uint64 m, uint64 n,
2219 std::complex<double> alpha,
2220 const DeviceMemory<std::complex<double>> &a, int lda,
2221 DeviceMemory<std::complex<double>> *b, int ldb) {
2222 LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2223 << "for the \"complex<double>\" dataype";
2224 return false;
2225 }
2226
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2227 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2228 blas::UpperLower uplo, blas::Transpose transa,
2229 blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2230 const DeviceMemory<float> &a, int lda,
2231 DeviceMemory<float> *b, int ldb) {
2232 return DoBlasInternal(
2233 wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
2234 ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2235 ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
2236 lda, GpuMemoryMutable(b), ldb);
2237 }
2238
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2239 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2240 blas::UpperLower uplo, blas::Transpose transa,
2241 blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2242 const DeviceMemory<double> &a, int lda,
2243 DeviceMemory<double> *b, int ldb) {
2244 return DoBlasInternal(
2245 wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
2246 ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2247 ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
2248 lda, GpuMemoryMutable(b), ldb);
2249 }
2250
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2251 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2252 blas::UpperLower uplo, blas::Transpose transa,
2253 blas::Diagonal diag, uint64 m, uint64 n,
2254 std::complex<float> alpha,
2255 const DeviceMemory<std::complex<float>> &a, int lda,
2256 DeviceMemory<std::complex<float>> *b, int ldb) {
2257 LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
2258 << "for the \"complex<float>\" dataype";
2259 return false;
2260 }
2261
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2262 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2263 blas::UpperLower uplo, blas::Transpose transa,
2264 blas::Diagonal diag, uint64 m, uint64 n,
2265 std::complex<double> alpha,
2266 const DeviceMemory<std::complex<double>> &a, int lda,
2267 DeviceMemory<std::complex<double>> *b, int ldb) {
2268 LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
2269 << "for the \"complex<double>\" dataype";
2270 return false;
2271 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)2272 bool ROCMBlas::DoBlasGemmStridedBatched(
2273 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2274 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
2275 int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
2276 int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
2277 int64 stride_c, int batch_count) {
2278 LOG(ERROR) << "rocBLAS does not currently support the "
2279 "DoBlasGemmStridedBatched operation "
2280 << "for the \"Eigen::half\" dataype";
2281 return false;
2282 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)2283 bool ROCMBlas::DoBlasGemmStridedBatched(
2284 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2285 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
2286 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
2287 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
2288 int batch_count) {
2289 LOG(ERROR) << "rocBLAS does not currently support the "
2290 "DoBlasGemmStridedBatched operation "
2291 << "for the \"float\" dataype";
2292 return false;
2293 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)2294 bool ROCMBlas::DoBlasGemmStridedBatched(
2295 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2296 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
2297 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
2298 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
2299 int batch_count) {
2300 LOG(ERROR) << "rocBLAS does not currently support the "
2301 "DoBlasGemmStridedBatched operation "
2302 << "for the \"double\" dataype";
2303 return false;
2304 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)2305 bool ROCMBlas::DoBlasGemmStridedBatched(
2306 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2307 uint64 n, uint64 k, std::complex<float> alpha,
2308 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
2309 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
2310 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
2311 int64 stride_c, int batch_count) {
2312 LOG(ERROR) << "rocBLAS does not currently support the "
2313 "DoBlasGemmStridedBatched operation "
2314 << "for the \"complex<float>\" dataype";
2315 return false;
2316 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)2317 bool ROCMBlas::DoBlasGemmStridedBatched(
2318 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2319 uint64 n, uint64 k, std::complex<double> alpha,
2320 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
2321 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
2322 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
2323 int64 stride_c, int batch_count) {
2324 LOG(ERROR) << "rocBLAS does not currently support the "
2325 "DoBlasGemmStridedBatched operation "
2326 << "for the \"complex<double>\" dataype";
2327 return false;
2328 }
2329 } // namespace gpu
2330
initialize_rocblas()2331 void initialize_rocblas() {
2332 auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
2333 rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2334
2335 if (!rocBlasAlreadyRegistered) {
2336 port::Status status =
2337 PluginRegistry::Instance()
2338 ->RegisterFactory<PluginRegistry::BlasFactory>(
2339 rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
2340 [](internal::StreamExecutorInterface *parent)
2341 -> blas::BlasSupport * {
2342 gpu::GpuExecutor *rocm_executor =
2343 dynamic_cast<gpu::GpuExecutor *>(parent);
2344 if (rocm_executor == nullptr) {
2345 LOG(ERROR)
2346 << "Attempting to initialize an instance of the "
2347 "rocBLAS "
2348 << "support library with a non-ROCM StreamExecutor";
2349 return nullptr;
2350 }
2351
2352 gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
2353 if (!blas->Init()) {
2354 // Note: Init() will log a more specific error.
2355 delete blas;
2356 return nullptr;
2357 }
2358 return blas;
2359 });
2360
2361 if (!status.ok()) {
2362 LOG(ERROR) << "Unable to register rocBLAS factory: "
2363 << status.error_message();
2364 }
2365
2366 PluginRegistry::Instance()->SetDefaultFactory(
2367 rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2368 }
2369 }
2370
2371 } // namespace stream_executor
2372
// Hooks initialize_rocblas() into the StreamExecutor module-initializer
// mechanism (lib/initialize.h), so the rocBLAS factory gets registered without
// an explicit call from client code. NOTE(review): exact timing (static-init vs.
// explicit module init) depends on the macro's expansion — confirm in
// initialize.h.
REGISTER_MODULE_INITIALIZER(register_rocblas,
                            { stream_executor::initialize_rocblas(); });
2375