• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Exposes the family of BLAS routines as pre-canned high performance calls for
17 // use in conjunction with the StreamExecutor abstraction.
18 //
19 // Note that this interface is optionally supported by platforms; see
20 // StreamExecutor::SupportsBlas() for details.
21 //
22 // This abstraction makes it simple to entrain BLAS operations on GPU data into
23 // a Stream -- users typically will not use this API directly, but will use the
24 // Stream builder methods to entrain these operations "under the hood". For
25 // example:
26 //
27 //  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
28 //  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
29 //  // ... populate x and y ...
30 //  Stream stream{stream_exec};
31 //  stream
32 //    .Init()
33 //    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
34 //  SE_CHECK_OK(stream.BlockHostUntilDone());
35 //
36 // By using stream operations in this manner the user can easily intermix custom
37 // kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS
38 // routines.
39 
40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
42 
#include <complex>
#include <limits>
#include <ostream>
#include <string>
#include <vector>

#include "tensorflow/stream_executor/dnn.h"  // For DataType, ToDataType
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
50 
// Forward declarations of the Eigen scalar types named by the
// ToComputationType specializations in this header. Declaring both here keeps
// the header from relying on some other include having declared them first.
namespace Eigen {
struct half;
struct bfloat16;
}  // namespace Eigen
54 
55 namespace stream_executor {
56 
// Types defined elsewhere in StreamExecutor. Declaring them here lets this
// header name them without pulling in their full definitions.
class ScratchAllocator;
class Stream;

template <typename T>
class DeviceMemory;
template <typename T>
class HostOrDeviceScalar;
65 
66 namespace blas {
67 
// How an input matrix is presented to a BLAS operation: unchanged,
// transposed, or transposed and conjugated.
enum class Transpose {
  kNoTranspose,
  kTranspose,
  kConjugateTranspose,
};

// Returns a string name for `t` (defined out of line).
std::string TransposeString(Transpose t);
74 
// Selects which triangle of a symmetric/Hermitian matrix is referenced.
enum class UpperLower {
  kUpper,  // Use the upper triangular part.
  kLower,  // Use the lower triangular part.
};

// Returns a string name for `ul` (defined out of line).
std::string UpperLowerString(UpperLower ul);
81 
// Whether a triangular matrix is unit triangular.
enum class Diagonal {
  kUnit,     // The matrix is unit triangular.
  kNonUnit,  // The matrix is not unit triangular.
};

// Returns a string name for `d` (defined out of line).
std::string DiagonalString(Diagonal d);
87 
// Which side of the operation a Hermitian matrix appears on.
enum class Side {
  kLeft,   // The matrix appears on the left.
  kRight,  // The matrix appears on the right.
};

// Returns a string name for `s` (defined out of line).
std::string SideString(Side s);
94 
// The type a BLAS routine uses for its intermediate arithmetic.
//
// This can differ from the element type of the routine's inputs and outputs.
// For example, a matmul of int8 matrices can keep its intermediate values in
// float32.
enum class ComputationType {
  // 16-bit floating point.
  kF16,
  // 32-bit floating point.
  kF32,
  // 64-bit floating point.
  kF64,
  // 32-bit integer.
  kI32,
  // Complex number made of two f32 components.
  kComplexF32,
  // Complex number made of two f64 components.
  kComplexF64,

  // The values below are supported only by BlasLt routines (real and
  // complex). They accumulate in float32 but first round the input mantissas
  // to fewer bits.

  // float32 with a reduced (>=10-bit) mantissa.
  kTF32AsF32,
  // float32 with a reduced (7-bit) mantissa.
  kBF16AsF32,
};
113 
// Post-processing applied to the results of a matmul (see
// BlasLtMatmulPlanParams::epilogue).
//
// The values behave like bit flags: kBiasThenReLU is kBias | kReLU.
// NOTE(review): the explicit numbering looks chosen to mirror cuBLASLt's
// cublasLtEpilogue_t. Confirm before renumbering.
enum class Epilogue {
  // No special post-processing.
  kDefault = 1,
  // Apply ReLU point-wise to the results.
  kReLU = 2,
  // Add a broadcast bias vector to the results.
  kBias = 4,
  // Add the bias first, then apply ReLU.
  kBiasThenReLU = kBias | kReLU,
};
120 
121 // Converts a ComputationType to a string.
122 std::string ComputationTypeString(ComputationType ty);
123 
124 template <typename T>
125 struct ToComputationType;
126 template <>
127 struct ToComputationType<float> {
128   static constexpr ComputationType value = ComputationType::kF32;
129 };
130 template <>
131 struct ToComputationType<double> {
132   static constexpr ComputationType value = ComputationType::kF64;
133 };
134 template <>
135 struct ToComputationType<Eigen::half> {
136   static constexpr ComputationType value = ComputationType::kF16;
137 };
138 template <>
139 struct ToComputationType<Eigen::bfloat16> {
140   static constexpr ComputationType value = ComputationType::kBF16AsF32;
141 };
142 template <>
143 struct ToComputationType<tensorflow::int32> {
144   static constexpr ComputationType value = ComputationType::kI32;
145 };
146 template <>
147 struct ToComputationType<std::complex<float>> {
148   static constexpr ComputationType value = ComputationType::kComplexF32;
149 };
150 template <>
151 struct ToComputationType<std::complex<double>> {
152   static constexpr ComputationType value = ComputationType::kComplexF64;
153 };
154 
155 std::ostream &operator<<(std::ostream &os, ComputationType ty);
156 
157 using dnn::DataType;
158 using dnn::ToDataType;
159 
160 // Describes the type of pointers for the scaling factors alpha and beta in
161 // blaslt routines.
162 enum class PointerMode {
163   kHost,
164   kDevice,
165 };
166 
167 // Converts a ComputationType to a string.
168 std::string DataTypeString(DataType ty);
169 
170 std::ostream &operator<<(std::ostream &os, DataType ty);
171 
172 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
173 // as a hint to the blas library.
174 typedef int64 AlgorithmType;
175 constexpr AlgorithmType kDefaultAlgorithm = -1;
176 constexpr AlgorithmType kDefaultBlasGemm = -2;
177 constexpr AlgorithmType kDefaultBlasGemv = -3;
178 constexpr AlgorithmType kNoAlgorithm = -4;
179 
180 // blas uses -1 to represent the default algorithm. This happens to match up
181 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
182 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
183 // to ensure that this assumption does not break.
184 // If another blas implementation uses a different value for the default
185 // algorithm, then it needs to convert kDefaultGemmAlgo to that value
186 // (e.g. via a function called ToWhateverGemmAlgo).
187 constexpr AlgorithmType kDefaultGemmAlgo = -1;
188 
189 // Describes the result of a performance experiment, usually timing the speed of
190 // a particular AlgorithmType.
191 //
192 // If the call we were benchmarking failed (a common occurrence; not all
193 // algorithms are valid for all calls), is_valid() will be false.
194 class ProfileResult {
195  public:
196   bool is_valid() const { return is_valid_; }
197   void set_is_valid(bool val) { is_valid_ = val; }
198   AlgorithmType algorithm() const { return algorithm_; }
199   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
200   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
201   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
202 
203  private:
204   bool is_valid_ = false;
205   AlgorithmType algorithm_ = kDefaultAlgorithm;
206   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
207 };
208 
209 class AlgorithmConfig {
210  public:
211   AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
212   explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
213   AlgorithmType algorithm() const { return algorithm_; }
214   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
215   bool operator==(const AlgorithmConfig &other) const {
216     return this->algorithm_ == other.algorithm_;
217   }
218   bool operator!=(const AlgorithmConfig &other) const {
219     return !(*this == other);
220   }
221   std::string ToString() const;
222 
223  private:
224   AlgorithmType algorithm_;
225 };
226 
227 struct IBlasLtMatmulPlan {
228   // Returns the data type of the A and B (input) matrices.
229   virtual DataType ab_type() const = 0;
230   // Returns the data type of the C (input/output) matrix.
231   virtual DataType c_type() const = 0;
232   virtual ~IBlasLtMatmulPlan() {}
233 };
234 
235 struct IBlasLtMatmulAlgorithm {
236   virtual ~IBlasLtMatmulAlgorithm() {}
237   // Returns the index of the algorithm within the list returned by
238   // GetBlasLtMatmulAlgorithms.
239   virtual AlgorithmType index() const = 0;
240   // Returns the workspace size required by the algorithm in bytes.
241   virtual size_t workspace_size() const = 0;
242 };
243 
244 // Parameters for the CreateBlasLtMatmulPlan method.
245 struct BlasLtMatmulPlanParams {
246   DataType ab_type;
247   DataType c_type;
248   ComputationType computation_type;
249   PointerMode pointer_mode;
250   Epilogue epilogue;
251   Transpose transa;
252   Transpose transb;
253   uint64 m;
254   uint64 n;
255   uint64 k;
256   int64 lda;
257   int64 ldb;
258   int64 ldc;
259   int batch_count = 1;
260   int64 stride_a = 0;
261   int64 stride_b = 0;
262   int64 stride_c = 0;
263 };
264 
265 // BLAS support interface -- this can be derived from a GPU executor when the
266 // underlying platform has a BLAS library implementation available. See
267 // StreamExecutor::AsBlas().
268 //
269 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in
270 // the system. Any operation that a user attempts to perform by enqueueing BLAS
271 // operations on a thread not-associated with the CUDA-context has unknown
272 // behavior at the current time; see b/13176597
273 class BlasSupport {
274  public:
275   virtual ~BlasSupport() {}
276 
277   // Computes the sum of magnitudes of the vector elements.
278   // result <- |Re x(1)| + |Im x(1)| + |Re  x(2)| + |Im  x(2)|+ ... + |Re  x(n)|
279   // + |Im x(n)|.
280   // Note that Im x(i) = 0 for real types float/double.
281   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
282                           const DeviceMemory<float> &x, int incx,
283                           DeviceMemory<float> *result) = 0;
284   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
285                           const DeviceMemory<double> &x, int incx,
286                           DeviceMemory<double> *result) = 0;
287   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
288                           const DeviceMemory<std::complex<float>> &x, int incx,
289                           DeviceMemory<float> *result) = 0;
290   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
291                           const DeviceMemory<std::complex<double>> &x, int incx,
292                           DeviceMemory<double> *result) = 0;
293 
294   // Performs a BLAS y <- ax+y operation.
295   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
296                           const DeviceMemory<float> &x, int incx,
297                           DeviceMemory<float> *y, int incy) = 0;
298   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
299                           const DeviceMemory<double> &x, int incx,
300                           DeviceMemory<double> *y, int incy) = 0;
301   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
302                           std::complex<float> alpha,
303                           const DeviceMemory<std::complex<float>> &x, int incx,
304                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
305   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
306                           std::complex<double> alpha,
307                           const DeviceMemory<std::complex<double>> &x, int incx,
308                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
309 
310   // Copies vector to another vector: y <- x.
311   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
312                           const DeviceMemory<float> &x, int incx,
313                           DeviceMemory<float> *y, int incy) = 0;
314   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
315                           const DeviceMemory<double> &x, int incx,
316                           DeviceMemory<double> *y, int incy) = 0;
317   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
318                           const DeviceMemory<std::complex<float>> &x, int incx,
319                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
320   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
321                           const DeviceMemory<std::complex<double>> &x, int incx,
322                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
323 
324   // Performs a BLAS dot product result <- x . y.
325   virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
326                          const DeviceMemory<float> &x, int incx,
327                          const DeviceMemory<float> &y, int incy,
328                          DeviceMemory<float> *result) = 0;
329   virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
330                          const DeviceMemory<double> &x, int incx,
331                          const DeviceMemory<double> &y, int incy,
332                          DeviceMemory<double> *result) = 0;
333 
334   // Performs a BLAS dot product result <- conj(x) . y for complex types.
335   virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
336                           const DeviceMemory<std::complex<float>> &x, int incx,
337                           const DeviceMemory<std::complex<float>> &y, int incy,
338                           DeviceMemory<std::complex<float>> *result) = 0;
339   virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
340                           const DeviceMemory<std::complex<double>> &x, int incx,
341                           const DeviceMemory<std::complex<double>> &y, int incy,
342                           DeviceMemory<std::complex<double>> *result) = 0;
343 
344   // Performs a BLAS dot product result <- x . y for complex types. Note that
345   // x is unconjugated in this routine.
346   virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
347                           const DeviceMemory<std::complex<float>> &x, int incx,
348                           const DeviceMemory<std::complex<float>> &y, int incy,
349                           DeviceMemory<std::complex<float>> *result) = 0;
350   virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
351                           const DeviceMemory<std::complex<double>> &x, int incx,
352                           const DeviceMemory<std::complex<double>> &y, int incy,
353                           DeviceMemory<std::complex<double>> *result) = 0;
354 
355   // Computes the Euclidean norm of a vector: result <- ||x||.
356   // See the following link for more information of Euclidean norm:
357   // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
358   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
359                           const DeviceMemory<float> &x, int incx,
360                           DeviceMemory<float> *result) = 0;
361   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
362                           const DeviceMemory<double> &x, int incx,
363                           DeviceMemory<double> *result) = 0;
364   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
365                           const DeviceMemory<std::complex<float>> &x, int incx,
366                           DeviceMemory<float> *result) = 0;
367   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
368                           const DeviceMemory<std::complex<double>> &x, int incx,
369                           DeviceMemory<double> *result) = 0;
370 
371   // Performs rotation of points in the plane:
372   // x(i) = c*x(i) + s*y(i)
373   // y(i) = c*y(i) - s*x(i).
374   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
375                          DeviceMemory<float> *x, int incx,
376                          DeviceMemory<float> *y, int incy, float c,
377                          float s) = 0;
378   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
379                          DeviceMemory<double> *x, int incx,
380                          DeviceMemory<double> *y, int incy, double c,
381                          double s) = 0;
382   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
383                          DeviceMemory<std::complex<float>> *x, int incx,
384                          DeviceMemory<std::complex<float>> *y, int incy,
385                          float c, float s) = 0;
386   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
387                          DeviceMemory<std::complex<double>> *x, int incx,
388                          DeviceMemory<std::complex<double>> *y, int incy,
389                          double c, double s) = 0;
390 
391   // Computes the parameters for a Givens rotation.
392   // Given the Cartesian coordinates (a, b) of a point, these routines return
393   // the parameters c, s, r, and z associated with the Givens rotation. The
394   // parameters c and s define a unitary matrix such that:
395   //
396   //   |  c s |.| a | = | r |
397   //   | -s c | | b |   | 0 |
398   //
399   // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
400   // c is not 0 z is 1/c; otherwise z is 1.
401   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
402                           DeviceMemory<float> *b, DeviceMemory<float> *c,
403                           DeviceMemory<float> *s) = 0;
404   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
405                           DeviceMemory<double> *b, DeviceMemory<double> *c,
406                           DeviceMemory<double> *s) = 0;
407   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
408                           DeviceMemory<std::complex<float>> *b,
409                           DeviceMemory<float> *c,
410                           DeviceMemory<std::complex<float>> *s) = 0;
411   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
412                           DeviceMemory<std::complex<double>> *b,
413                           DeviceMemory<double> *c,
414                           DeviceMemory<std::complex<double>> *s) = 0;
415 
416   // Performs modified Givens rotation of points in the plane.
417   // Given two vectors x and y, each vector element of these vectors is replaced
418   // as follows:
419   //
420   //   | x(i) | =  H | x(i) |
421   //   | y(i) |      | y(i) |
422   //
423   // for i=1 to n, where H is a modified Givens transformation matrix whose
424   // values are stored in the param[1] through param[4] array.
425   // For more information please Google this routine.
426   virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
427                           DeviceMemory<float> *x, int incx,
428                           DeviceMemory<float> *y, int incy,
429                           const DeviceMemory<float> &param) = 0;
430   virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
431                           DeviceMemory<double> *x, int incx,
432                           DeviceMemory<double> *y, int incy,
433                           const DeviceMemory<double> &param) = 0;
434 
435   // Computes the parameters for a modified Givens rotation.
436   // Given Cartesian coordinates (x1, y1) of an input vector, these routines
437   // compute the components of a modified Givens transformation matrix H that
438   // zeros the y-component of the resulting vector:
439   //
440   //   | x1 | =  H | x1 * sqrt(d1) |
441   //   |  0 |      | y1 * sqrt(d1) |
442   //
443   // For more information please Google this routine.
444   virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
445                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
446                            const DeviceMemory<float> &y1,
447                            DeviceMemory<float> *param) = 0;
448   virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
449                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
450                            const DeviceMemory<double> &y1,
451                            DeviceMemory<double> *param) = 0;
452 
453   // Computes the product of a vector by a scalar: x <- a*x.
454   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
455                           DeviceMemory<float> *x, int incx) = 0;
456   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
457                           DeviceMemory<double> *x, int incx) = 0;
458   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
459                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
460   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
461                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
462   virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
463                           std::complex<float> alpha,
464                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
465   virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
466                           std::complex<double> alpha,
467                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
468 
469   // Swaps a vector with another vector.
470   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
471                           DeviceMemory<float> *x, int incx,
472                           DeviceMemory<float> *y, int incy) = 0;
473   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
474                           DeviceMemory<double> *x, int incx,
475                           DeviceMemory<double> *y, int incy) = 0;
476   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
477                           DeviceMemory<std::complex<float>> *x, int incx,
478                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
479   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
480                           DeviceMemory<std::complex<double>> *x, int incx,
481                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
482 
483   // Finds the index of the element with maximum absolute value.
484   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
485                            const DeviceMemory<float> &x, int incx,
486                            DeviceMemory<int> *result) = 0;
487   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
488                            const DeviceMemory<double> &x, int incx,
489                            DeviceMemory<int> *result) = 0;
490   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
491                            const DeviceMemory<std::complex<float>> &x, int incx,
492                            DeviceMemory<int> *result) = 0;
493   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
494                            const DeviceMemory<std::complex<double>> &x,
495                            int incx, DeviceMemory<int> *result) = 0;
496 
497   // Finds the index of the element with minimum absolute value.
498   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
499                            const DeviceMemory<float> &x, int incx,
500                            DeviceMemory<int> *result) = 0;
501   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
502                            const DeviceMemory<double> &x, int incx,
503                            DeviceMemory<int> *result) = 0;
504   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
505                            const DeviceMemory<std::complex<float>> &x, int incx,
506                            DeviceMemory<int> *result) = 0;
507   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
508                            const DeviceMemory<std::complex<double>> &x,
509                            int incx, DeviceMemory<int> *result) = 0;
510 
511   // Computes a matrix-vector product using a general band matrix:
512   //
513   //     y <- alpha * a * x + beta * y,
514   // or
515   //     y <- alpha * a' * x + beta * y,
516   // or
517   //     y <- alpha * conj(a') * x + beta * y,
518   //
519   // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
520   // sub-diagonals and ku super-diagonals; x is a vector with
521   // n(trans==kNoTranspose)/m(otherwise) elements;
522   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
523   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
524                           uint64 n, uint64 kl, uint64 ku, float alpha,
525                           const DeviceMemory<float> &a, int lda,
526                           const DeviceMemory<float> &x, int incx, float beta,
527                           DeviceMemory<float> *y, int incy) = 0;
528   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
529                           uint64 n, uint64 kl, uint64 ku, double alpha,
530                           const DeviceMemory<double> &a, int lda,
531                           const DeviceMemory<double> &x, int incx, double beta,
532                           DeviceMemory<double> *y, int incy) = 0;
533   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
534                           uint64 n, uint64 kl, uint64 ku,
535                           std::complex<float> alpha,
536                           const DeviceMemory<std::complex<float>> &a, int lda,
537                           const DeviceMemory<std::complex<float>> &x, int incx,
538                           std::complex<float> beta,
539                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
540   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
541                           uint64 n, uint64 kl, uint64 ku,
542                           std::complex<double> alpha,
543                           const DeviceMemory<std::complex<double>> &a, int lda,
544                           const DeviceMemory<std::complex<double>> &x, int incx,
545                           std::complex<double> beta,
546                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
547 
548   // Computes a matrix-vector product using a general matrix.
549   //
550   //     y <- alpha * a * x + beta * y,
551   // or
552   //     y <- alpha * a' * x + beta * y,
553   // or
554   //     y <- alpha * conj(a') * x + beta * y,
555   //
556   // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
557   // with n(trans==kNoTranspose)/m(otherwise) elements;
558   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
559   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
560                           uint64 n, float alpha, const DeviceMemory<float> &a,
561                           int lda, const DeviceMemory<float> &x, int incx,
562                           float beta, DeviceMemory<float> *y, int incy) = 0;
563   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
564                           uint64 n, double alpha, const DeviceMemory<double> &a,
565                           int lda, const DeviceMemory<double> &x, int incx,
566                           double beta, DeviceMemory<double> *y, int incy) = 0;
567   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
568                           uint64 n, std::complex<float> alpha,
569                           const DeviceMemory<std::complex<float>> &a, int lda,
570                           const DeviceMemory<std::complex<float>> &x, int incx,
571                           std::complex<float> beta,
572                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
573   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
574                           uint64 n, std::complex<double> alpha,
575                           const DeviceMemory<std::complex<double>> &a, int lda,
576                           const DeviceMemory<std::complex<double>> &x, int incx,
577                           std::complex<double> beta,
578                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
579 
580   virtual bool DoBlasGemvWithProfiling(
581       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
582       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
583       int incx, float beta, DeviceMemory<float> *y, int incy,
584       ProfileResult *output_profile_result) = 0;
585   virtual bool DoBlasGemvWithProfiling(
586       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
587       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
588       int incx, double beta, DeviceMemory<double> *y, int incy,
589       ProfileResult *output_profile_result) = 0;
590   virtual bool DoBlasGemvWithProfiling(
591       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
592       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
593       int lda, const DeviceMemory<std::complex<float>> &x, int incx,
594       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
595       ProfileResult *output_profile_result) = 0;
596   virtual bool DoBlasGemvWithProfiling(
597       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
598       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
599       int lda, const DeviceMemory<std::complex<double>> &x, int incx,
600       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
601       int incy, ProfileResult *output_profile_result) = 0;
602 
603   // Performs a rank-1 update of a general matrix.
604   //
605   //     a <- alpha * x * y' + a,
606   //
607   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
608   // an m-by-n general matrix.
609   virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
610                          const DeviceMemory<float> &x, int incx,
611                          const DeviceMemory<float> &y, int incy,
612                          DeviceMemory<float> *a, int lda) = 0;
613   virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
614                          const DeviceMemory<double> &x, int incx,
615                          const DeviceMemory<double> &y, int incy,
616                          DeviceMemory<double> *a, int lda) = 0;
617 
618   // Performs a rank-1 update (conjugated) of a general matrix.
619   //
620   //     a <- alpha * x * conj(y') + a,
621   //
622   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
623   // an m-by-n general matrix.
624   virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,  // complex<float> overload (y is conjugated).
625                           std::complex<float> alpha,
626                           const DeviceMemory<std::complex<float>> &x, int incx,
627                           const DeviceMemory<std::complex<float>> &y, int incy,
628                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
629   virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,  // complex<double> overload (y is conjugated).
630                           std::complex<double> alpha,
631                           const DeviceMemory<std::complex<double>> &x, int incx,
632                           const DeviceMemory<std::complex<double>> &y, int incy,
633                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
634 
635   // Performs a rank-1 update (unconjugated) of a general matrix.
636   //
637   //     a <- alpha * x * y' + a,
638   //
639   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
640   // an m-by-n general matrix.
641   virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,  // complex<float> overload (y is not conjugated).
642                           std::complex<float> alpha,
643                           const DeviceMemory<std::complex<float>> &x, int incx,
644                           const DeviceMemory<std::complex<float>> &y, int incy,
645                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
646   virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,  // complex<double> overload (y is not conjugated).
647                           std::complex<double> alpha,
648                           const DeviceMemory<std::complex<double>> &x, int incx,
649                           const DeviceMemory<std::complex<double>> &y, int incy,
650                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
651 
652   // Computes a matrix-vector product using a Hermitian band matrix.
653   //
654   //     y <- alpha * a * x + beta * y,
655   //
656   // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
657   // super-diagonals; x and y are n-element vectors.
658   virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> overload; k = number of super-diagonals of a.
659                           uint64 k, std::complex<float> alpha,
660                           const DeviceMemory<std::complex<float>> &a, int lda,
661                           const DeviceMemory<std::complex<float>> &x, int incx,
662                           std::complex<float> beta,
663                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
664   virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> overload; k = number of super-diagonals of a.
665                           uint64 k, std::complex<double> alpha,
666                           const DeviceMemory<std::complex<double>> &a, int lda,
667                           const DeviceMemory<std::complex<double>> &x, int incx,
668                           std::complex<double> beta,
669                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
670 
671   // Computes a matrix-vector product using a Hermitian matrix.
672   //
673   //     y <- alpha * a * x + beta * y,
674   //
675   // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
676   // n-element vectors.
677   virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> overload; uplo presumably selects the referenced triangle (TODO confirm).
678                           std::complex<float> alpha,
679                           const DeviceMemory<std::complex<float>> &a, int lda,
680                           const DeviceMemory<std::complex<float>> &x, int incx,
681                           std::complex<float> beta,
682                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
683   virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> overload.
684                           std::complex<double> alpha,
685                           const DeviceMemory<std::complex<double>> &a, int lda,
686                           const DeviceMemory<std::complex<double>> &x, int incx,
687                           std::complex<double> beta,
688                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
689 
690   // Performs a rank-1 update of a Hermitian matrix.
691   //
692   //     a <- alpha * x * conj(x') + a,
693   //
694   // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
695   // matrix.
696   virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> data with a real (float) alpha.
697                          float alpha,
698                          const DeviceMemory<std::complex<float>> &x, int incx,
699                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
700   virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> data with a real (double) alpha.
701                          double alpha,
702                          const DeviceMemory<std::complex<double>> &x, int incx,
703                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
704 
705   // Performs a rank-2 update of a Hermitian matrix.
706   //
707   //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
708   //
709   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
710   // matrix.
711   virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> overload; alpha is complex, unlike DoBlasHer.
712                           std::complex<float> alpha,
713                           const DeviceMemory<std::complex<float>> &x, int incx,
714                           const DeviceMemory<std::complex<float>> &y, int incy,
715                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
716   virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> overload; alpha is complex, unlike DoBlasHer.
717                           std::complex<double> alpha,
718                           const DeviceMemory<std::complex<double>> &x, int incx,
719                           const DeviceMemory<std::complex<double>> &y, int incy,
720                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
721 
722   // Computes a matrix-vector product using a Hermitian packed matrix.
723   //
724   //     y <- alpha * a * x + beta * y,
725   //
726   // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
727   // packed form; x and y are n-element vectors.
728   virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> overload; ap is packed, hence no leading dimension.
729                           std::complex<float> alpha,
730                           const DeviceMemory<std::complex<float>> &ap,
731                           const DeviceMemory<std::complex<float>> &x, int incx,
732                           std::complex<float> beta,
733                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
734   virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> overload; ap is packed, hence no leading dimension.
735                           std::complex<double> alpha,
736                           const DeviceMemory<std::complex<double>> &ap,
737                           const DeviceMemory<std::complex<double>> &x, int incx,
738                           std::complex<double> beta,
739                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
740 
741   // Performs a rank-1 update of a Hermitian packed matrix.
742   //
743   //     a <- alpha * x * conj(x') + a,
744   //
745   // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
746   // matrix, supplied in packed form.
747   virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> data, real alpha; ap is packed (no lda).
748                          float alpha,
749                          const DeviceMemory<std::complex<float>> &x, int incx,
750                          DeviceMemory<std::complex<float>> *ap) = 0;
751   virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> data, real alpha; ap is packed (no lda).
752                          double alpha,
753                          const DeviceMemory<std::complex<double>> &x, int incx,
754                          DeviceMemory<std::complex<double>> *ap) = 0;
755 
756   // Performs a rank-2 update of a Hermitian packed matrix.
757   //
758   //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
759   //
760   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
761   // matrix, supplied in packed form.
762   virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<float> overload; ap is packed (no lda).
763                           std::complex<float> alpha,
764                           const DeviceMemory<std::complex<float>> &x, int incx,
765                           const DeviceMemory<std::complex<float>> &y, int incy,
766                           DeviceMemory<std::complex<float>> *ap) = 0;
767   virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,  // complex<double> overload; ap is packed (no lda).
768                           std::complex<double> alpha,
769                           const DeviceMemory<std::complex<double>> &x, int incx,
770                           const DeviceMemory<std::complex<double>> &y, int incy,
771                           DeviceMemory<std::complex<double>> *ap) = 0;
772 
773   // Computes a matrix-vector product using a symmetric band matrix.
774   //
775   //     y <- alpha * a * x + beta * y,
776   //
777   // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
778   // super-diagonals; x and y are n-element vectors.
779   virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; k = number of super-diagonals of a.
780                           uint64 k, float alpha, const DeviceMemory<float> &a,
781                           int lda, const DeviceMemory<float> &x, int incx,
782                           float beta, DeviceMemory<float> *y, int incy) = 0;
783   virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; k = number of super-diagonals of a.
784                           uint64 k, double alpha, const DeviceMemory<double> &a,
785                           int lda, const DeviceMemory<double> &x, int incx,
786                           double beta, DeviceMemory<double> *y, int incy) = 0;
787 
788   // Computes a matrix-vector product using a symmetric packed matrix.
789   //
790   //     y <- alpha * a * x + beta * y,
791   //
792   // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
793   // packed form; x and y are n-element vectors.
794   virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; ap is packed (no lda).
795                           float alpha, const DeviceMemory<float> &ap,
796                           const DeviceMemory<float> &x, int incx, float beta,
797                           DeviceMemory<float> *y, int incy) = 0;
798   virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; ap is packed (no lda).
799                           double alpha, const DeviceMemory<double> &ap,
800                           const DeviceMemory<double> &x, int incx, double beta,
801                           DeviceMemory<double> *y, int incy) = 0;
802 
803   // Performs a rank-1 update of a symmetric packed matrix.
804   //
805   //     a <- alpha * x * x' + a,
806   //
807   // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
808   // matrix, supplied in packed form.
809   virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; packed ap is updated in place.
810                          float alpha, const DeviceMemory<float> &x, int incx,
811                          DeviceMemory<float> *ap) = 0;
812   virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; packed ap is updated in place.
813                          double alpha, const DeviceMemory<double> &x, int incx,
814                          DeviceMemory<double> *ap) = 0;
815 
816   // Performs a rank-2 update of a symmetric packed matrix.
817   //
818   //     a <- alpha * x * y' + alpha * y * x' + a,
819   //
820   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
821   // matrix, supplied in packed form.
822   virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; packed ap is updated in place.
823                           float alpha, const DeviceMemory<float> &x, int incx,
824                           const DeviceMemory<float> &y, int incy,
825                           DeviceMemory<float> *ap) = 0;
826   virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; packed ap is updated in place.
827                           double alpha, const DeviceMemory<double> &x, int incx,
828                           const DeviceMemory<double> &y, int incy,
829                           DeviceMemory<double> *ap) = 0;
830 
831   // Computes a matrix-vector product for a symmetric matrix.
832   //
833   //     y <- alpha * a * x + beta * y,
834   //
835   // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
836   // n-element vectors.
837   virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; a is unpacked with leading dimension lda.
838                           float alpha, const DeviceMemory<float> &a, int lda,
839                           const DeviceMemory<float> &x, int incx, float beta,
840                           DeviceMemory<float> *y, int incy) = 0;
841   virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; a is unpacked with leading dimension lda.
842                           double alpha, const DeviceMemory<double> &a, int lda,
843                           const DeviceMemory<double> &x, int incx, double beta,
844                           DeviceMemory<double> *y, int incy) = 0;
845 
846   // Performs a rank-1 update of a symmetric matrix.
847   //
848   //     a <- alpha * x * x' + a,
849   //
850   // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
851   // matrix.
852   virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; a is updated in place.
853                          float alpha, const DeviceMemory<float> &x, int incx,
854                          DeviceMemory<float> *a, int lda) = 0;
855   virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; a is updated in place.
856                          double alpha, const DeviceMemory<double> &x, int incx,
857                          DeviceMemory<double> *a, int lda) = 0;
858 
859   // Performs a rank-2 update of a symmetric matrix.
860   //
861   //     a <- alpha * x * y' + alpha * y * x' + a,
862   //
863   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
864   // matrix.
865   virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,  // float overload; a is updated in place.
866                           float alpha, const DeviceMemory<float> &x, int incx,
867                           const DeviceMemory<float> &y, int incy,
868                           DeviceMemory<float> *a, int lda) = 0;
869   virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,  // double overload; a is updated in place.
870                           double alpha, const DeviceMemory<double> &x, int incx,
871                           const DeviceMemory<double> &y, int incy,
872                           DeviceMemory<double> *a, int lda) = 0;
873 
874   // Computes a matrix-vector product using a triangular band matrix.
875   //
876   //     x <- a * x,
877   // or
878   //     x <- a' * x,
879   // or
880   //     x <- conj(a') * x,
881   //
882   // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
883   // with k+1 diagonals; x is an n-element vector.
884   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,  // float overload; x is overwritten in place with the product.
885                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
886                           uint64 k, const DeviceMemory<float> &a, int lda,
887                           DeviceMemory<float> *x, int incx) = 0;
888   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,  // double overload; x is overwritten in place with the product.
889                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
890                           uint64 k, const DeviceMemory<double> &a, int lda,
891                           DeviceMemory<double> *x, int incx) = 0;
892   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,  // complex<float> overload; x is overwritten in place.
893                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
894                           uint64 k, const DeviceMemory<std::complex<float>> &a,
895                           int lda, DeviceMemory<std::complex<float>> *x,
896                           int incx) = 0;
897   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,  // complex<double> overload; x is overwritten in place.
898                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
899                           uint64 k, const DeviceMemory<std::complex<double>> &a,
900                           int lda, DeviceMemory<std::complex<double>> *x,
901                           int incx) = 0;
902 
903   // Solves a system of linear equations whose coefficients are in a triangular
904   // band matrix as below:
905   //
906   //     a * x = b,
907   // or
908   //     a' * x = b,
909   // or
910   //     conj(a') * x = b,
911   //
912   // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
913   // lower triangular band matrix, with k+1 diagonals.
914   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,  // float overload; x presumably holds b on entry, the solution on exit (TODO confirm).
915                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
916                           uint64 k, const DeviceMemory<float> &a, int lda,
917                           DeviceMemory<float> *x, int incx) = 0;
918   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,  // double overload; x presumably holds b on entry, the solution on exit (TODO confirm).
919                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
920                           uint64 k, const DeviceMemory<double> &a, int lda,
921                           DeviceMemory<double> *x, int incx) = 0;
922   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,  // complex<float> overload; x is the in/out vector.
923                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
924                           uint64 k, const DeviceMemory<std::complex<float>> &a,
925                           int lda, DeviceMemory<std::complex<float>> *x,
926                           int incx) = 0;
927   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,  // complex<double> overload; x is the in/out vector.
928                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
929                           uint64 k, const DeviceMemory<std::complex<double>> &a,
930                           int lda, DeviceMemory<std::complex<double>> *x,
931                           int incx) = 0;
932 
933   // Computes a matrix-vector product using a triangular packed matrix.
934   //
935   //     x <- a * x,
936   // or
937   //     x <- a' * x,
938   // or
939   //     x <- conj(a') * x,
940   //
941   // a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
942   // supplied in packed form; x is an n-element vector.
943   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,  // float overload; packed ap, x overwritten in place.
944                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
945                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
946                           int incx) = 0;
947   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,  // double overload; packed ap, x overwritten in place.
948                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
949                           const DeviceMemory<double> &ap,
950                           DeviceMemory<double> *x, int incx) = 0;
951   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,  // complex<float> overload; packed ap, x overwritten in place.
952                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
953                           const DeviceMemory<std::complex<float>> &ap,
954                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
955   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,  // complex<double> overload; packed ap, x overwritten in place.
956                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
957                           const DeviceMemory<std::complex<double>> &ap,
958                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
959 
960   // Solves a system of linear equations whose coefficients are in a triangular
961   // packed matrix as below:
962   //
963   //     a * x = b,
964   // or
965   //     a' * x = b,
966   // or
967   //     conj(a') * x = b,
968   //
969   // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
970   // lower triangular matrix, supplied in packed form.
971   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,  // float overload; packed ap; x presumably holds b on entry, solution on exit (TODO confirm).
972                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
973                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
974                           int incx) = 0;
975   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,  // double overload; packed ap; x is the in/out vector.
976                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
977                           const DeviceMemory<double> &ap,
978                           DeviceMemory<double> *x, int incx) = 0;
979   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,  // complex<float> overload; packed ap; x is the in/out vector.
980                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
981                           const DeviceMemory<std::complex<float>> &ap,
982                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
983   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,  // complex<double> overload; packed ap; x is the in/out vector.
984                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
985                           const DeviceMemory<std::complex<double>> &ap,
986                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
987 
988   // Computes a matrix-vector product using a triangular matrix.
989   //
990   //     x <- a * x,
991   // or
992   //     x <- a' * x,
993   // or
994   //     x <- conj(a') * x,
995   //
996   // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is
997   // an n-element vector.
998   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,  // float overload; x is overwritten in place with the product.
999                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1000                           const DeviceMemory<float> &a, int lda,
1001                           DeviceMemory<float> *x, int incx) = 0;
1002   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,  // double overload; x is overwritten in place with the product.
1003                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1004                           const DeviceMemory<double> &a, int lda,
1005                           DeviceMemory<double> *x, int incx) = 0;
1006   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,  // complex<float> overload; x is overwritten in place.
1007                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1008                           const DeviceMemory<std::complex<float>> &a, int lda,
1009                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
1010   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,  // complex<double> overload; x is overwritten in place.
1011                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1012                           const DeviceMemory<std::complex<double>> &a, int lda,
1013                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
1014 
1015   // Solves a system of linear equations whose coefficients are in a triangular
1016   // matrix as below:
1017   //
1018   //     a * x = b,
1019   // or
1020   //     a' * x = b,
1021   // or
1022   //     conj(a') * x = b,
1023   //
1024   // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
1025   // lower triangular matrix.
1026   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,  // float overload; x presumably holds b on entry, the solution on exit (TODO confirm).
1027                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1028                           const DeviceMemory<float> &a, int lda,
1029                           DeviceMemory<float> *x, int incx) = 0;
1030   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,  // double overload; x is the in/out vector.
1031                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1032                           const DeviceMemory<double> &a, int lda,
1033                           DeviceMemory<double> *x, int incx) = 0;
1034   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,  // complex<float> overload; x is the in/out vector.
1035                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1036                           const DeviceMemory<std::complex<float>> &a, int lda,
1037                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
1038   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,  // complex<double> overload; x is the in/out vector.
1039                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1040                           const DeviceMemory<std::complex<double>> &a, int lda,
1041                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
1042 
1043   // Computes a matrix-matrix product with general matrices:
1044   //
1045   //     c <- alpha * op(a) * op(b) + beta * c,
1046   //
1047   // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
1048   // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
1049   // op(b) is a k-by-n matrix; c is an m-by-n matrix.
1050   //
1051   // Note: The half interface uses float precision internally; the version
1052   // that uses half precision internally is not yet supported. A batched
1053   // half-precision overload is provided separately by DoBlasGemmBatched.
1054   //
1055   // Alpha/beta type matches `dtype`, unless `dtype` denotes `Eigen::half`, in
1056   // which case the expected alpha/beta type is `float`.
1057   virtual port::Status DoBlasGemm(Stream *stream, blas::Transpose transa,  // Type-erased GEMM: alpha/beta point to scalars whose type depends on dtype.
1058                                   blas::Transpose transb, uint64 m, uint64 n,
1059                                   uint64 k, DataType dtype, const void *alpha,
1060                                   const DeviceMemoryBase &a, int lda,
1061                                   const DeviceMemoryBase &b, int ldb,
1062                                   const void *beta, DeviceMemoryBase *c,
1063                                   int ldc) = 0;
1064 
1065   virtual bool DoBlasGemmWithProfiling(  // Eigen::half matrices with float alpha/beta; output_profile_result presumably receives timing (TODO confirm).
1066       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1067       uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1068       int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1069       DeviceMemory<Eigen::half> *c, int ldc,
1070       ProfileResult *output_profile_result) = 0;
1071   virtual bool DoBlasGemmWithProfiling(  // float overload of the profiling GEMM.
1072       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1073       uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1074       const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1075       int ldc, ProfileResult *output_profile_result) = 0;
1076   virtual bool DoBlasGemmWithProfiling(  // double overload of the profiling GEMM.
1077       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1078       uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1079       const DeviceMemory<double> &b, int ldb, double beta,
1080       DeviceMemory<double> *c, int ldc,
1081       ProfileResult *output_profile_result) = 0;
1082   virtual bool DoBlasGemmWithProfiling(  // complex<float> overload of the profiling GEMM.
1083       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1084       uint64 n, uint64 k, std::complex<float> alpha,
1085       const DeviceMemory<std::complex<float>> &a, int lda,
1086       const DeviceMemory<std::complex<float>> &b, int ldb,
1087       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1088       ProfileResult *output_profile_result) = 0;
1089   virtual bool DoBlasGemmWithProfiling(  // complex<double> overload of the profiling GEMM.
1090       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1091       uint64 n, uint64 k, std::complex<double> alpha,
1092       const DeviceMemory<std::complex<double>> &a, int lda,
1093       const DeviceMemory<std::complex<double>> &b, int ldb,
1094       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1095       ProfileResult *output_profile_result) = 0;
1096 
1097   // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
1098   virtual bool GetBlasGemmAlgorithms(  // Output list in *out_algorithms; whether it is cleared first is not visible here.
1099       std::vector<AlgorithmType> *out_algorithms) = 0;
1100 
1101   // Like DoBlasGemm, but accepts an algorithm and a compute type.
1102   //
1103   // The compute type lets you say (e.g.) that the inputs and outputs are
1104   // Eigen::halfs, but you want the internal computations to be done with
1105   // float32 precision.
1106   //
1107   // If output_profile_result is not null, a failure here does not put the
1108   // stream in a failure state.  Instead, success/failure is indicated by
1109   // output_profile_result->is_valid().  This lets you use this function for
1110   // choosing the best algorithm among many (some of which may fail) without
1111   // creating a new Stream for each attempt.
1112   virtual port::Status DoBlasGemmWithAlgorithm(  // Each operand carries its own DataType (type_a / type_b / type_c).
1113       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1114       uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
1115       DataType type_a, int lda, const DeviceMemoryBase &b, DataType type_b,
1116       int ldb, const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc,
1117       ComputationType computation_type, AlgorithmType algorithm,
1118       ProfileResult *output_profile_result) = 0;
1119 
1120   virtual port::Status DoBlasGemmStridedBatchedWithAlgorithm(  // Appears to be DoBlasGemmWithAlgorithm plus stride_* / batch_count (TODO confirm profiling semantics match).
1121       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1122       uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
1123       DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
1124       DataType type_b, int ldb, int64_t stride_b, const void *beta,
1125       DeviceMemoryBase *c, DataType type_c, int ldc, int64_t stride_c,
1126       int batch_count, ComputationType computation_type,
1127       AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
1128 
1129   // Computes a batch of matrix-matrix products with general matrices.
1130   // This is a batched version of DoBlasGemm.
1131   // The batched GEMM computes matrix product for each input/output in a, b,
1132   // and c, which contain batch_count DeviceMemory objects.
1133   virtual bool DoBlasGemmBatched(  // Eigen::half matrices with float alpha/beta; scratch_allocator presumably backs temporaries (TODO confirm).
1134       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1135       uint64 n, uint64 k, float alpha,
1136       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1137       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1138       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1139       int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
1140   virtual bool DoBlasGemmBatched(  // float overload; a, b, c each hold batch_count matrix pointers.
1141       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1142       uint64 n, uint64 k, float alpha,
1143       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
1144       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
1145       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
1146       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1147   virtual bool DoBlasGemmBatched(  // double overload; a, b, c each hold batch_count matrix pointers.
1148       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1149       uint64 n, uint64 k, double alpha,
1150       const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
1151       const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
1152       const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
1153       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1154   virtual bool DoBlasGemmBatched(  // complex<float> overload; a, b, c each hold batch_count matrix pointers.
1155       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1156       uint64 n, uint64 k, std::complex<float> alpha,
1157       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1158       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1159       std::complex<float> beta,
1160       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1161       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1162   virtual bool DoBlasGemmBatched(  // complex<double> overload; a, b, c each hold batch_count matrix pointers.
1163       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1164       uint64 n, uint64 k, std::complex<double> alpha,
1165       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1166       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1167       std::complex<double> beta,
1168       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1169       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1170 
1171   // Batched gemm with strides instead of pointer arrays.
       //
       // Each of a, b and c is a single untyped buffer holding batch_count
       // matrices. Consecutive matrices are presumably stride_a / stride_b /
       // stride_c apart; confirm whether the strides count elements or bytes.
       // The buffers are DeviceMemoryBase, so the element type is passed
       // explicitly as dtype. alpha and beta are type-erased `const void *`
       // scalars. Their pointee type presumably matches dtype; confirm that,
       // and confirm whether they must be host pointers. Unlike the
       // pointer-array DoBlasGemmBatched overloads, this reports failures
       // through port::Status rather than bool and takes no ScratchAllocator.
1172   virtual port::Status DoBlasGemmStridedBatched(
1173       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1174       uint64 n, uint64 k, DataType dtype, const void *alpha,
1175       const DeviceMemoryBase &a, int lda, int64_t stride_a,
1176       const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
1177       DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) = 0;
1178 
1179   // Computes a matrix-matrix product where one input matrix is Hermitian:
1180   //
1181   //     c <- alpha * a * b + beta * c,
1182   // or
1183   //     c <- alpha * b * a + beta * c,
1184   //
1185   // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
1186   // matrices.
       //
       // Since b and c are m-by-n, a is m-by-m in the first form and n-by-n in
       // the second. Presumably side chooses between the two forms and uplo
       // says which triangle of a is referenced; confirm against the
       // implementations. c is updated in place. Overloads exist for
       // std::complex<float> and std::complex<double>, and alpha and beta are
       // complex.
1187   virtual bool DoBlasHemm(Stream *stream, blas::Side side,
1188                           blas::UpperLower uplo, uint64 m, uint64 n,
1189                           std::complex<float> alpha,
1190                           const DeviceMemory<std::complex<float>> &a, int lda,
1191                           const DeviceMemory<std::complex<float>> &b, int ldb,
1192                           std::complex<float> beta,
1193                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1194   virtual bool DoBlasHemm(Stream *stream, blas::Side side,
1195                           blas::UpperLower uplo, uint64 m, uint64 n,
1196                           std::complex<double> alpha,
1197                           const DeviceMemory<std::complex<double>> &a, int lda,
1198                           const DeviceMemory<std::complex<double>> &b, int ldb,
1199                           std::complex<double> beta,
1200                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1201 
1202   // Performs a Hermitian rank-k update.
1203   //
1204   //     c <- alpha * a * conj(a') + beta * c,
1205   // or
1206   //     c <- alpha * conj(a') * a + beta * c,
1207   //
1208   // alpha and beta are real scalars; c is an n-by-n Hermitian matrix; a is
1209   // an n-by-k matrix in the first case and a k-by-n matrix in the second case.
       //
       // As the signatures below show, alpha and beta are float/double even
       // though a and c are complex. Presumably trans selects between the two
       // forms and uplo says which triangle of c is referenced and updated;
       // confirm against the implementations. c is updated in place.
1210   virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
1211                           blas::Transpose trans, uint64 n, uint64 k,
1212                           float alpha,
1213                           const DeviceMemory<std::complex<float>> &a, int lda,
1214                           float beta, DeviceMemory<std::complex<float>> *c,
1215                           int ldc) = 0;
1216   virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
1217                           blas::Transpose trans, uint64 n, uint64 k,
1218                           double alpha,
1219                           const DeviceMemory<std::complex<double>> &a, int lda,
1220                           double beta, DeviceMemory<std::complex<double>> *c,
1221                           int ldc) = 0;
1222 
1223   // Performs a Hermitian rank-2k update.
1224   //
1225   //     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
1226   // or
1227   //     c <- alpha * conj(a') * b + conj(alpha) * conj(b') * a + beta * c,
1228   //
1229   // alpha and beta are scalars; c is an n-by-n Hermitian matrix; a and b are
1230   // n-by-k matrices in the first case and k-by-n matrices in the second case.
       //
       // The second form matches the reference BLAS xHER2K definition,
       // C := alpha*A**H*B + conjg(alpha)*B**H*A + beta*C. alpha multiplies
       // conj(a') * b, not its conjugate transpose conj(b') * a. As the
       // signatures below show, alpha is complex but beta is real
       // (float/double). Presumably trans selects between the two forms and
       // uplo says which triangle of c is referenced and updated; confirm
       // against the implementations.
1231   virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
1232                            blas::Transpose trans, uint64 n, uint64 k,
1233                            std::complex<float> alpha,
1234                            const DeviceMemory<std::complex<float>> &a, int lda,
1235                            const DeviceMemory<std::complex<float>> &b, int ldb,
1236                            float beta, DeviceMemory<std::complex<float>> *c,
1237                            int ldc) = 0;
1238   virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
1239                            blas::Transpose trans, uint64 n, uint64 k,
1240                            std::complex<double> alpha,
1241                            const DeviceMemory<std::complex<double>> &a, int lda,
1242                            const DeviceMemory<std::complex<double>> &b, int ldb,
1243                            double beta, DeviceMemory<std::complex<double>> *c,
1244                            int ldc) = 0;
1245 
1246   // Computes a matrix-matrix product where one input matrix is symmetric.
1247   //
1248   //     c <- alpha * a * b + beta * c,
1249   // or
1250   //     c <- alpha * b * a + beta * c,
1251   //
1252   // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
1253   // matrices.
       //
       // Since b and c are m-by-n, a is m-by-m in the first form and n-by-n in
       // the second. Presumably side chooses between the two forms and uplo
       // says which triangle of a is referenced; confirm against the
       // implementations. Overloads exist for float, double,
       // std::complex<float> and std::complex<double>. In the complex overloads
       // a is symmetric (a == a'), not Hermitian; DoBlasHemm covers the
       // Hermitian case. c is updated in place.
1254   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1255                           blas::UpperLower uplo, uint64 m, uint64 n,
1256                           float alpha, const DeviceMemory<float> &a, int lda,
1257                           const DeviceMemory<float> &b, int ldb, float beta,
1258                           DeviceMemory<float> *c, int ldc) = 0;
1259   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1260                           blas::UpperLower uplo, uint64 m, uint64 n,
1261                           double alpha, const DeviceMemory<double> &a, int lda,
1262                           const DeviceMemory<double> &b, int ldb, double beta,
1263                           DeviceMemory<double> *c, int ldc) = 0;
1264   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1265                           blas::UpperLower uplo, uint64 m, uint64 n,
1266                           std::complex<float> alpha,
1267                           const DeviceMemory<std::complex<float>> &a, int lda,
1268                           const DeviceMemory<std::complex<float>> &b, int ldb,
1269                           std::complex<float> beta,
1270                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1271   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1272                           blas::UpperLower uplo, uint64 m, uint64 n,
1273                           std::complex<double> alpha,
1274                           const DeviceMemory<std::complex<double>> &a, int lda,
1275                           const DeviceMemory<std::complex<double>> &b, int ldb,
1276                           std::complex<double> beta,
1277                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1278 
1279   // Performs a symmetric rank-k update.
1280   //
1281   //     c <- alpha * a * a' + beta * c,
1282   // or
1283   //     c <- alpha * a' * a + beta * c,
1284   //
1285   // alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
1286   // matrix in the first case and a k-by-n matrix in the second case.
       //
       // Here a' is the plain transpose: the std::complex overloads apply no
       // conjugation, unlike DoBlasHerk, which uses conj(a'). alpha and beta
       // have the element type, so they are complex in the complex overloads.
       // Presumably trans selects between the two forms and uplo says which
       // triangle of c is referenced and updated; confirm against the
       // implementations.
1287   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1288                           blas::Transpose trans, uint64 n, uint64 k,
1289                           float alpha, const DeviceMemory<float> &a, int lda,
1290                           float beta, DeviceMemory<float> *c, int ldc) = 0;
1291   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1292                           blas::Transpose trans, uint64 n, uint64 k,
1293                           double alpha, const DeviceMemory<double> &a, int lda,
1294                           double beta, DeviceMemory<double> *c, int ldc) = 0;
1295   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1296                           blas::Transpose trans, uint64 n, uint64 k,
1297                           std::complex<float> alpha,
1298                           const DeviceMemory<std::complex<float>> &a, int lda,
1299                           std::complex<float> beta,
1300                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1301   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1302                           blas::Transpose trans, uint64 n, uint64 k,
1303                           std::complex<double> alpha,
1304                           const DeviceMemory<std::complex<double>> &a, int lda,
1305                           std::complex<double> beta,
1306                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1307 
1308   // Performs a symmetric rank-2k update.
1309   //
1310   //     c <- alpha * a * b' + alpha * b * a' + beta * c,
1311   // or
1312   //     c <- alpha * b' * a + alpha * a' * b + beta * c,
1313   //
1314   // alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
1315   // n-by-k matrices in the first case and k-by-n matrices in the second case.
       //
       // Here ' is the plain (non-conjugating) transpose, in the std::complex
       // overloads too; the Hermitian variants write conj(x') instead. alpha
       // and beta have the element type. Presumably trans selects between the
       // two forms and uplo says which triangle of c is referenced and updated;
       // confirm against the implementations.
1316   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1317                            blas::Transpose trans, uint64 n, uint64 k,
1318                            float alpha, const DeviceMemory<float> &a, int lda,
1319                            const DeviceMemory<float> &b, int ldb, float beta,
1320                            DeviceMemory<float> *c, int ldc) = 0;
1321   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1322                            blas::Transpose trans, uint64 n, uint64 k,
1323                            double alpha, const DeviceMemory<double> &a, int lda,
1324                            const DeviceMemory<double> &b, int ldb, double beta,
1325                            DeviceMemory<double> *c, int ldc) = 0;
1326   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1327                            blas::Transpose trans, uint64 n, uint64 k,
1328                            std::complex<float> alpha,
1329                            const DeviceMemory<std::complex<float>> &a, int lda,
1330                            const DeviceMemory<std::complex<float>> &b, int ldb,
1331                            std::complex<float> beta,
1332                            DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1333   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1334                            blas::Transpose trans, uint64 n, uint64 k,
1335                            std::complex<double> alpha,
1336                            const DeviceMemory<std::complex<double>> &a, int lda,
1337                            const DeviceMemory<std::complex<double>> &b, int ldb,
1338                            std::complex<double> beta,
1339                            DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1340 
1341   // Computes a matrix-matrix product where one input matrix is triangular.
1342   //
1343   //     b <- alpha * op(a) * b,
1344   // or
1345   //     b <- alpha * b * op(a)
1346   //
1347   // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
1348   // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
1349   // op(a) = conj(a').
       //
       // The result overwrites b in place; there is no separate output
       // argument. Presumably side chooses between the two forms, uplo picks
       // upper vs. lower, transa picks op(), and diag picks unit vs. non-unit;
       // confirm against the implementations.
1350   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1351                           blas::UpperLower uplo, blas::Transpose transa,
1352                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
1353                           const DeviceMemory<float> &a, int lda,
1354                           DeviceMemory<float> *b, int ldb) = 0;
1355   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1356                           blas::UpperLower uplo, blas::Transpose transa,
1357                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
1358                           const DeviceMemory<double> &a, int lda,
1359                           DeviceMemory<double> *b, int ldb) = 0;
1360   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1361                           blas::UpperLower uplo, blas::Transpose transa,
1362                           blas::Diagonal diag, uint64 m, uint64 n,
1363                           std::complex<float> alpha,
1364                           const DeviceMemory<std::complex<float>> &a, int lda,
1365                           DeviceMemory<std::complex<float>> *b, int ldb) = 0;
1366   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1367                           blas::UpperLower uplo, blas::Transpose transa,
1368                           blas::Diagonal diag, uint64 m, uint64 n,
1369                           std::complex<double> alpha,
1370                           const DeviceMemory<std::complex<double>> &a, int lda,
1371                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1372 
1373   // Solves a triangular matrix equation.
1374   //
1375   //     op(a) * x = alpha * b,
1376   // or
1377   //     x * op(a) = alpha * b
1378   //
1379   // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
1380   // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
1381   // or op(a) = conj(a').
       //
       // There is no separate x argument and b is the only mutable operand, so
       // the solution x is presumably written over b. Presumably side chooses
       // between the two forms, uplo picks upper vs. lower, transa picks op(),
       // and diag picks unit vs. non-unit; confirm against the
       // implementations.
1382   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1383                           blas::UpperLower uplo, blas::Transpose transa,
1384                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
1385                           const DeviceMemory<float> &a, int lda,
1386                           DeviceMemory<float> *b, int ldb) = 0;
1387   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1388                           blas::UpperLower uplo, blas::Transpose transa,
1389                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
1390                           const DeviceMemory<double> &a, int lda,
1391                           DeviceMemory<double> *b, int ldb) = 0;
1392   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1393                           blas::UpperLower uplo, blas::Transpose transa,
1394                           blas::Diagonal diag, uint64 m, uint64 n,
1395                           std::complex<float> alpha,
1396                           const DeviceMemory<std::complex<float>> &a, int lda,
1397                           DeviceMemory<std::complex<float>> *b, int ldb) = 0;
1398   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1399                           blas::UpperLower uplo, blas::Transpose transa,
1400                           blas::Diagonal diag, uint64 m, uint64 n,
1401                           std::complex<double> alpha,
1402                           const DeviceMemory<std::complex<double>> &a, int lda,
1403                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1404 
1405   // Creates a backend-specific plan object for a blaslt matmul operation, which
1406   // can then be passed to DoBlasLtMatmul(). When possible, plans should be
1407   // created once and reused for multiple calls to DoBlasLtMatmul().
       //
       // The caller owns the returned plan (std::unique_ptr), and failures are
       // reported through the StatusOr. The plan exposes the ab_type() and
       // c_type() that the typed DoBlasLtMatmul overload checks its arguments
       // against.
1408   virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
1409   CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) = 0;
1410 
1411   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
1412   // returned in the order of increasing estimated compute time according to an
1413   // internal heuristic. The first returned algorithm can be used as the default
1414   // algorithm if no autotuning is to be performed.
       //
       // max_workspace_size presumably bounds the scratch memory (in bytes) a
       // returned algorithm may require, and max_algorithm_count presumably
       // caps the length of the returned list; confirm against the
       // implementations. The caller owns the returned algorithm objects
       // (std::unique_ptr), and failures are reported through the StatusOr.
1415   virtual port::StatusOr<
1416       std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
1417   GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
1418                             size_t max_workspace_size,
1419                             int max_algorithm_count) = 0;
1420 
1421   // Executes a blaslt matmul operation on the stream. If output_profile_result
1422   // is not nullptr, the operation is profiled, error messages are
1423   // suppressed, and output_profile_result->algorithm() is set to
1424   // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
1425   // creating the plan, the bias argument here must refer to a valid device
1426   // vector of length equal to the number of rows in matrix c. If epilogue was
1427   // set to any other value then the bias argument here must be null. The bias
1428   // vector is broadcast across the batch dimension.
1429   // Note that the data types of a and b (c and bias) must match the ab_type
1430   // (c_type) with which the plan was created, and the data types of alpha and
1431   // beta must match the data type of c.
       //
       // This overload is type-erased: the buffers are untyped
       // DeviceMemoryBase and the scalars are HostOrDeviceScalar<void>, so it
       // cannot enforce the type requirements above at compile time. The
       // templated DoBlasLtMatmul overload in this class compares the element
       // types against the plan before forwarding here. A false return
       // presumably signals failure.
1432   virtual bool DoBlasLtMatmul(
1433       Stream *stream, const blas::IBlasLtMatmulPlan *plan,
1434       const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
1435       DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
1436       DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
1437       const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
1438       blas::ProfileResult *output_profile_result) = 0;
1439 
1440   template <typename ABType, typename CType>
1441   bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan,
1442                       const HostOrDeviceScalar<CType> &alpha,
1443                       const DeviceMemory<ABType> &a,
1444                       const DeviceMemory<ABType> &b,
1445                       const HostOrDeviceScalar<CType> &beta,
1446                       DeviceMemory<CType> *c,
1447                       ScratchAllocator *scratch_allocator,
1448                       const blas::IBlasLtMatmulAlgorithm *algorithm,
1449                       const DeviceMemory<CType> &bias = {},
1450                       blas::ProfileResult *output_profile_result = nullptr) {
1451     constexpr blas::DataType ab_type = blas::ToDataType<ABType>::value;
1452     if (ab_type != plan->ab_type()) {
1453       VLOG(2) << "DoBlasLtMatmul returning false because a and b type does "
1454                  "not match plan: expected "
1455               << plan->ab_type() << ", got " << ab_type;
1456       return false;
1457     }
1458     constexpr blas::DataType c_type = blas::ToDataType<CType>::value;
1459     if (c_type != plan->c_type()) {
1460       VLOG(2) << "DoBlasLtMatmul returning false because c type does "
1461                  "not match plan: expected "
1462               << plan->c_type() << ", got " << c_type;
1463       return false;
1464     }
1465     return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c,
1466                           scratch_allocator, algorithm, bias,
1467                           output_profile_result);
1468   }
1469 
       // Reports a version string for the BLAS backend through the *version
       // out-parameter (presumably the underlying library's version; confirm
       // against the implementations). Failures are reported through the
       // returned port::Status.
1470   virtual port::Status GetVersion(std::string *version) = 0;
1471 
1472  protected:
       // BlasSupport declares pure virtual methods, so it can only be
       // constructed as the base of a platform-specific implementation. The
       // constructor is protected accordingly.
1473   BlasSupport() {}
1474
1475  private:
       // Presumably declares the copy constructor and copy assignment as
       // deleted or private, which makes BlasSupport non-copyable; check the
       // macro's definition to confirm.
1476   SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
1477 };
1478 
1479 // Macro used to quickly declare overrides for abstract virtuals in the
1480 // BlasSupport base class.
1481 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES                  \
1482   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1483                   const DeviceMemory<float> &x, int incx,                      \
1484                   DeviceMemory<float> *result) override;                       \
1485   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1486                   const DeviceMemory<double> &x, int incx,                     \
1487                   DeviceMemory<double> *result) override;                      \
1488   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1489                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1490                   DeviceMemory<float> *result) override;                       \
1491   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1492                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1493                   DeviceMemory<double> *result) override;                      \
1494   bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,              \
1495                   const DeviceMemory<float> &x, int incx,                      \
1496                   DeviceMemory<float> *y, int incy) override;                  \
1497   bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,             \
1498                   const DeviceMemory<double> &x, int incx,                     \
1499                   DeviceMemory<double> *y, int incy) override;                 \
1500   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1501                   std::complex<float> alpha,                                   \
1502                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1503                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1504   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1505                   std::complex<double> alpha,                                  \
1506                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1507                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1508   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1509                   const DeviceMemory<float> &x, int incx,                      \
1510                   DeviceMemory<float> *y, int incy) override;                  \
1511   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1512                   const DeviceMemory<double> &x, int incx,                     \
1513                   DeviceMemory<double> *y, int incy) override;                 \
1514   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1515                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1516                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1517   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1518                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1519                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1520   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1521                  const DeviceMemory<float> &x, int incx,                       \
1522                  const DeviceMemory<float> &y, int incy,                       \
1523                  DeviceMemory<float> *result) override;                        \
1524   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1525                  const DeviceMemory<double> &x, int incx,                      \
1526                  const DeviceMemory<double> &y, int incy,                      \
1527                  DeviceMemory<double> *result) override;                       \
1528   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1529                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1530                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1531                   DeviceMemory<std::complex<float>> *result) override;         \
1532   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1533                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1534                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1535                   DeviceMemory<std::complex<double>> *result) override;        \
1536   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1537                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1538                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1539                   DeviceMemory<std::complex<float>> *result) override;         \
1540   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1541                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1542                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1543                   DeviceMemory<std::complex<double>> *result) override;        \
1544   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1545                   const DeviceMemory<float> &x, int incx,                      \
1546                   DeviceMemory<float> *result) override;                       \
1547   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1548                   const DeviceMemory<double> &x, int incx,                     \
1549                   DeviceMemory<double> *result) override;                      \
1550   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1551                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1552                   DeviceMemory<float> *result) override;                       \
1553   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1554                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1555                   DeviceMemory<double> *result) override;                      \
1556   bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,    \
1557                  int incx, DeviceMemory<float> *y, int incy, float c, float s) \
1558       override;                                                                \
1559   bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,   \
1560                  int incx, DeviceMemory<double> *y, int incy, double c,        \
1561                  double s) override;                                           \
1562   bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1563                  DeviceMemory<std::complex<float>> *x, int incx,               \
1564                  DeviceMemory<std::complex<float>> *y, int incy, float c,      \
1565                  float s) override;                                            \
1566   bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1567                  DeviceMemory<std::complex<double>> *x, int incx,              \
1568                  DeviceMemory<std::complex<double>> *y, int incy, double c,    \
1569                  double s) override;                                           \
1570   bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,                      \
1571                   DeviceMemory<float> *b, DeviceMemory<float> *c,              \
1572                   DeviceMemory<float> *s) override;                            \
1573   bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,                     \
1574                   DeviceMemory<double> *b, DeviceMemory<double> *c,            \
1575                   DeviceMemory<double> *s) override;                           \
1576   bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,        \
1577                   DeviceMemory<std::complex<float>> *b,                        \
1578                   DeviceMemory<float> *c,                                      \
1579                   DeviceMemory<std::complex<float>> *s) override;              \
1580   bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,       \
1581                   DeviceMemory<std::complex<double>> *b,                       \
1582                   DeviceMemory<double> *c,                                     \
1583                   DeviceMemory<std::complex<double>> *s) override;             \
1584   bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1585                   int incx, DeviceMemory<float> *y, int incy,                  \
1586                   const DeviceMemory<float> &param) override;                  \
1587   bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1588                   int incx, DeviceMemory<double> *y, int incy,                 \
1589                   const DeviceMemory<double> &param) override;                 \
1590   bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,                    \
1591                    DeviceMemory<float> *d2, DeviceMemory<float> *x1,           \
1592                    const DeviceMemory<float> &y1, DeviceMemory<float> *param)  \
1593       override;                                                                \
1594   bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,                   \
1595                    DeviceMemory<double> *d2, DeviceMemory<double> *x1,         \
1596                    const DeviceMemory<double> &y1,                             \
1597                    DeviceMemory<double> *param) override;                      \
1598   bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1599                   DeviceMemory<float> *x, int incx) override;                  \
1600   bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1601                   DeviceMemory<double> *x, int incx) override;                 \
1602   bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1603                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1604   bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1605                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1606   bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1607                   std::complex<float> alpha,                                   \
1608                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1609   bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1610                   std::complex<double> alpha,                                  \
1611                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1612   bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1613                   int incx, DeviceMemory<float> *y, int incy) override;        \
1614   bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1615                   int incx, DeviceMemory<double> *y, int incy) override;       \
1616   bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1617                   DeviceMemory<std::complex<float>> *x, int incx,              \
1618                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1619   bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1620                   DeviceMemory<std::complex<double>> *x, int incx,             \
1621                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1622   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1623                    const DeviceMemory<float> &x, int incx,                     \
1624                    DeviceMemory<int> *result) override;                        \
1625   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1626                    const DeviceMemory<double> &x, int incx,                    \
1627                    DeviceMemory<int> *result) override;                        \
1628   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1629                    const DeviceMemory<std::complex<float>> &x, int incx,       \
1630                    DeviceMemory<int> *result) override;                        \
1631   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1632                    const DeviceMemory<std::complex<double>> &x, int incx,      \
1633                    DeviceMemory<int> *result) override;                        \
1634   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1635                    const DeviceMemory<float> &x, int incx,                     \
1636                    DeviceMemory<int> *result) override;                        \
1637   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1638                    const DeviceMemory<double> &x, int incx,                    \
1639                    DeviceMemory<int> *result) override;                        \
1640   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1641                    const DeviceMemory<std::complex<float>> &x, int incx,       \
1642                    DeviceMemory<int> *result) override;                        \
1643   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1644                    const DeviceMemory<std::complex<double>> &x, int incx,      \
1645                    DeviceMemory<int> *result) override;                        \
1646   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1647                   uint64 kl, uint64 ku, float alpha,                           \
1648                   const DeviceMemory<float> &a, int lda,                       \
1649                   const DeviceMemory<float> &x, int incx, float beta,          \
1650                   DeviceMemory<float> *y, int incy) override;                  \
1651   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1652                   uint64 kl, uint64 ku, double alpha,                          \
1653                   const DeviceMemory<double> &a, int lda,                      \
1654                   const DeviceMemory<double> &x, int incx, double beta,        \
1655                   DeviceMemory<double> *y, int incy) override;                 \
1656   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1657                   uint64 kl, uint64 ku, std::complex<float> alpha,             \
1658                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1659                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1660                   std::complex<float> beta,                                    \
1661                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1662   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1663                   uint64 kl, uint64 ku, std::complex<double> alpha,            \
1664                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1665                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1666                   std::complex<double> beta,                                   \
1667                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1668   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1669                   float alpha, const DeviceMemory<float> &a, int lda,          \
1670                   const DeviceMemory<float> &x, int incx, float beta,          \
1671                   DeviceMemory<float> *y, int incy) override;                  \
1672   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1673                   double alpha, const DeviceMemory<double> &a, int lda,        \
1674                   const DeviceMemory<double> &x, int incx, double beta,        \
1675                   DeviceMemory<double> *y, int incy) override;                 \
1676   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1677                   std::complex<float> alpha,                                   \
1678                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1679                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1680                   std::complex<float> beta,                                    \
1681                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1682   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1683                   std::complex<double> alpha,                                  \
1684                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1685                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1686                   std::complex<double> beta,                                   \
1687                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1688   bool DoBlasGemvWithProfiling(                                                \
1689       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,  \
1690       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,     \
1691       int incx, float beta, DeviceMemory<float> *y, int incy,                  \
1692       blas::ProfileResult *output_profile_result) override;                    \
1693   bool DoBlasGemvWithProfiling(                                                \
1694       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
1695       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,   \
1696       int incx, double beta, DeviceMemory<double> *y, int incy,                \
1697       blas::ProfileResult *output_profile_result) override;                    \
1698   bool DoBlasGemvWithProfiling(                                                \
1699       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
1700       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,   \
1701       int lda, const DeviceMemory<std::complex<float>> &x, int incx,           \
1702       std::complex<float> beta, DeviceMemory<std::complex<float>> *y,          \
1703       int incy, blas::ProfileResult *output_profile_result) override;          \
1704   bool DoBlasGemvWithProfiling(                                                \
1705       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
1706       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
1707       int lda, const DeviceMemory<std::complex<double>> &x, int incx,          \
1708       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,        \
1709       int incy, blas::ProfileResult *output_profile_result) override;          \
1710   bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,              \
1711                  const DeviceMemory<float> &x, int incx,                       \
1712                  const DeviceMemory<float> &y, int incy,                       \
1713                  DeviceMemory<float> *a, int lda) override;                    \
1714   bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,             \
1715                  const DeviceMemory<double> &x, int incx,                      \
1716                  const DeviceMemory<double> &y, int incy,                      \
1717                  DeviceMemory<double> *a, int lda) override;                   \
1718   bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1719                   std::complex<float> alpha,                                   \
1720                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1721                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1722                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1723   bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1724                   std::complex<double> alpha,                                  \
1725                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1726                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1727                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1728   bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1729                   std::complex<float> alpha,                                   \
1730                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1731                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1732                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1733   bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1734                   std::complex<double> alpha,                                  \
1735                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1736                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1737                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1738   bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1739                   std::complex<float> alpha,                                   \
1740                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1741                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1742                   std::complex<float> beta,                                    \
1743                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1744   bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1745                   std::complex<double> alpha,                                  \
1746                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1747                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1748                   std::complex<double> beta,                                   \
1749                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1750   bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1751                   std::complex<float> alpha,                                   \
1752                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1753                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1754                   std::complex<float> beta,                                    \
1755                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1756   bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1757                   std::complex<double> alpha,                                  \
1758                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1759                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1760                   std::complex<double> beta,                                   \
1761                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1762   bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1763                  const DeviceMemory<std::complex<float>> &x, int incx,         \
1764                  DeviceMemory<std::complex<float>> *a, int lda) override;      \
1765   bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1766                  double alpha, const DeviceMemory<std::complex<double>> &x,    \
1767                  int incx, DeviceMemory<std::complex<double>> *a, int lda)     \
1768       override;                                                                \
1769   bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1770                   std::complex<float> alpha,                                   \
1771                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1772                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1773                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1774   bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1775                   std::complex<double> alpha,                                  \
1776                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1777                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1778                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1779   bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1780                   std::complex<float> alpha,                                   \
1781                   const DeviceMemory<std::complex<float>> &ap,                 \
1782                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1783                   std::complex<float> beta,                                    \
1784                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1785   bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1786                   std::complex<double> alpha,                                  \
1787                   const DeviceMemory<std::complex<double>> &ap,                \
1788                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1789                   std::complex<double> beta,                                   \
1790                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1791   bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1792                  const DeviceMemory<std::complex<float>> &x, int incx,         \
1793                  DeviceMemory<std::complex<float>> *ap) override;              \
1794   bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1795                  double alpha, const DeviceMemory<std::complex<double>> &x,    \
1796                  int incx, DeviceMemory<std::complex<double>> *ap) override;   \
1797   bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1798                   std::complex<float> alpha,                                   \
1799                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1800                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1801                   DeviceMemory<std::complex<float>> *ap) override;             \
1802   bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1803                   std::complex<double> alpha,                                  \
1804                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1805                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1806                   DeviceMemory<std::complex<double>> *ap) override;            \
1807   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1808                   float alpha, const DeviceMemory<float> &a, int lda,          \
1809                   const DeviceMemory<float> &x, int incx, float beta,          \
1810                   DeviceMemory<float> *y, int incy) override;                  \
1811   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1812                   double alpha, const DeviceMemory<double> &a, int lda,        \
1813                   const DeviceMemory<double> &x, int incx, double beta,        \
1814                   DeviceMemory<double> *y, int incy) override;                 \
1815   bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1816                   float alpha, const DeviceMemory<float> &ap,                  \
1817                   const DeviceMemory<float> &x, int incx, float beta,          \
1818                   DeviceMemory<float> *y, int incy) override;                  \
1819   bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1820                   double alpha, const DeviceMemory<double> &ap,                \
1821                   const DeviceMemory<double> &x, int incx, double beta,        \
1822                   DeviceMemory<double> *y, int incy) override;                 \
1823   bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1824                  const DeviceMemory<float> &x, int incx,                       \
1825                  DeviceMemory<float> *ap) override;                            \
1826   bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1827                  double alpha, const DeviceMemory<double> &x, int incx,        \
1828                  DeviceMemory<double> *ap) override;                           \
1829   bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1830                   float alpha, const DeviceMemory<float> &x, int incx,         \
1831                   const DeviceMemory<float> &y, int incy,                      \
1832                   DeviceMemory<float> *ap) override;                           \
1833   bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1834                   double alpha, const DeviceMemory<double> &x, int incx,       \
1835                   const DeviceMemory<double> &y, int incy,                     \
1836                   DeviceMemory<double> *ap) override;                          \
1837   bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1838                   float alpha, const DeviceMemory<float> &a, int lda,          \
1839                   const DeviceMemory<float> &x, int incx, float beta,          \
1840                   DeviceMemory<float> *y, int incy) override;                  \
1841   bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1842                   double alpha, const DeviceMemory<double> &a, int lda,        \
1843                   const DeviceMemory<double> &x, int incx, double beta,        \
1844                   DeviceMemory<double> *y, int incy) override;                 \
1845   bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1846                  const DeviceMemory<float> &x, int incx,                       \
1847                  DeviceMemory<float> *a, int lda) override;                    \
1848   bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1849                  double alpha, const DeviceMemory<double> &x, int incx,        \
1850                  DeviceMemory<double> *a, int lda) override;                   \
1851   bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1852                   float alpha, const DeviceMemory<float> &x, int incx,         \
1853                   const DeviceMemory<float> &y, int incy,                      \
1854                   DeviceMemory<float> *a, int lda) override;                   \
1855   bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1856                   double alpha, const DeviceMemory<double> &x, int incx,       \
1857                   const DeviceMemory<double> &y, int incy,                     \
1858                   DeviceMemory<double> *a, int lda) override;                  \
1859   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1860                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1861                   uint64 k, const DeviceMemory<float> &a, int lda,             \
1862                   DeviceMemory<float> *x, int incx) override;                  \
1863   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1864                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1865                   uint64 k, const DeviceMemory<double> &a, int lda,            \
1866                   DeviceMemory<double> *x, int incx) override;                 \
1867   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1868                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1869                   uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1870                   int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1871       override;                                                                \
1872   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1873                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1874                   uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1875                   int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1876       override;                                                                \
1877   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1878                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1879                   uint64 k, const DeviceMemory<float> &a, int lda,             \
1880                   DeviceMemory<float> *x, int incx) override;                  \
1881   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1882                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1883                   uint64 k, const DeviceMemory<double> &a, int lda,            \
1884                   DeviceMemory<double> *x, int incx) override;                 \
1885   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1886                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1887                   uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1888                   int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1889       override;                                                                \
1890   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1891                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1892                   uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1893                   int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1894       override;                                                                \
1895   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1896                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1897                   const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1898                   int incx) override;                                          \
1899   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1900                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1901                   const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1902                   int incx) override;                                          \
1903   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1904                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1905                   const DeviceMemory<std::complex<float>> &ap,                 \
1906                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1907   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1908                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1909                   const DeviceMemory<std::complex<double>> &ap,                \
1910                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1911   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1912                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1913                   const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1914                   int incx) override;                                          \
1915   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1916                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1917                   const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1918                   int incx) override;                                          \
1919   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1920                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1921                   const DeviceMemory<std::complex<float>> &ap,                 \
1922                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1923   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1924                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1925                   const DeviceMemory<std::complex<double>> &ap,                \
1926                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1927   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1928                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1929                   const DeviceMemory<float> &a, int lda,                       \
1930                   DeviceMemory<float> *x, int incx) override;                  \
1931   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1932                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1933                   const DeviceMemory<double> &a, int lda,                      \
1934                   DeviceMemory<double> *x, int incx) override;                 \
1935   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1936                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1937                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1938                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1939   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1940                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1941                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1942                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1943   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1944                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1945                   const DeviceMemory<float> &a, int lda,                       \
1946                   DeviceMemory<float> *x, int incx) override;                  \
1947   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1948                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1949                   const DeviceMemory<double> &a, int lda,                      \
1950                   DeviceMemory<double> *x, int incx) override;                 \
1951   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1952                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1953                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1954                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1955   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1956                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1957                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1958                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1959   port::Status DoBlasGemm(                                                     \
1960       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1961       uint64 m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha,   \
1962       const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb,  \
1963       const void *beta, DeviceMemoryBase *c, int ldc) override;                \
1964   bool DoBlasGemmWithProfiling(                                                \
1965       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1966       uint64 m, uint64 n, uint64 k, float alpha,                               \
1967       const DeviceMemory<Eigen::half> &a, int lda,                             \
1968       const DeviceMemory<Eigen::half> &b, int ldb, float beta,                 \
1969       DeviceMemory<Eigen::half> *c, int ldc,                                   \
1970       blas::ProfileResult *output_profile_result) override;                    \
1971   bool DoBlasGemmWithProfiling(                                                \
1972       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1973       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
1974       int lda, const DeviceMemory<float> &b, int ldb, float beta,              \
1975       DeviceMemory<float> *c, int ldc,                                         \
1976       blas::ProfileResult *output_profile_result) override;                    \
1977   bool DoBlasGemmWithProfiling(                                                \
1978       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1979       uint64 m, uint64 n, uint64 k, double alpha,                              \
1980       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
1981       int ldb, double beta, DeviceMemory<double> *c, int ldc,                  \
1982       blas::ProfileResult *output_profile_result) override;                    \
1983   bool DoBlasGemmWithProfiling(                                                \
1984       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1985       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
1986       const DeviceMemory<std::complex<float>> &a, int lda,                     \
1987       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
1988       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
1989       blas::ProfileResult *output_profile_result) override;                    \
1990   bool DoBlasGemmWithProfiling(                                                \
1991       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1992       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
1993       const DeviceMemory<std::complex<double>> &a, int lda,                    \
1994       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
1995       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
1996       int ldc, blas::ProfileResult *output_profile_result) override;           \
1997   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
1998       override;                                                                \
1999   port::Status DoBlasGemmWithAlgorithm(                                        \
2000       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2001       uint64 m, uint64 n, uint64 k, const void *alpha,                         \
2002       const DeviceMemoryBase &a, blas::DataType type_a, int lda,               \
2003       const DeviceMemoryBase &b, blas::DataType type_b, int ldb,               \
2004       const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc,   \
2005       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
2006       blas::ProfileResult *output_profile_result) override;                    \
2007   bool DoBlasGemmBatched(                                                      \
2008       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2009       uint64 m, uint64 n, uint64 k, float alpha,                               \
2010       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,         \
2011       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,         \
2012       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,      \
2013       int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
2014   bool DoBlasGemmBatched(                                                      \
2015       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2016       uint64 m, uint64 n, uint64 k, float alpha,                               \
2017       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,               \
2018       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,   \
2019       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,               \
2020       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2021   bool DoBlasGemmBatched(                                                      \
2022       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2023       uint64 m, uint64 n, uint64 k, double alpha,                              \
2024       const port::ArraySlice<DeviceMemory<double> *> &a, int lda,              \
2025       const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
2026       const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,              \
2027       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2028   bool DoBlasGemmBatched(                                                      \
2029       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2030       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2031       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \
2032       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
2033       std::complex<float> beta,                                                \
2034       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
2035       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2036   bool DoBlasGemmBatched(                                                      \
2037       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2038       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2039       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a,         \
2040       int lda,                                                                 \
2041       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b,         \
2042       int ldb, std::complex<double> beta,                                      \
2043       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c,         \
2044       int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
2045   port::Status DoBlasGemmStridedBatched(                                       \
2046       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2047       uint64 m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha,   \
2048       const DeviceMemoryBase &a, int lda, int64 stride_a,                      \
2049       const DeviceMemoryBase &b, int ldb, int64 stride_b, const void *beta,    \
2050       DeviceMemoryBase *c, int ldc, int64 stride_c, int batch_count);          \
2051   port::Status DoBlasGemmStridedBatchedWithAlgorithm(                          \
2052       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2053       uint64 m, uint64 n, uint64 k, const void *alpha,                         \
2054       const DeviceMemoryBase &a, blas::DataType type_a, int lda,               \
2055       int64 stride_a, const DeviceMemoryBase &b, blas::DataType type_b,        \
2056       int ldb, int64 stride_b, const void *beta, DeviceMemoryBase *c,          \
2057       blas::DataType type_c, int ldc, int64 stride_c, int batch_count,         \
2058       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
2059       blas::ProfileResult *output_profile_result) override;                    \
2060   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2061                   uint64 m, uint64 n, std::complex<float> alpha,               \
2062                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2063                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2064                   std::complex<float> beta,                                    \
2065                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2066   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2067                   uint64 m, uint64 n, std::complex<double> alpha,              \
2068                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2069                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2070                   std::complex<double> beta,                                   \
2071                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2072   bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
2073                   blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
2074                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2075                   float beta, DeviceMemory<std::complex<float>> *c, int ldc)   \
2076       override;                                                                \
2077   bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
2078                   blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
2079                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2080                   double beta, DeviceMemory<std::complex<double>> *c, int ldc) \
2081       override;                                                                \
2082   bool DoBlasHer2k(                                                            \
2083       Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
2084       uint64 k, std::complex<float> alpha,                                     \
2085       const DeviceMemory<std::complex<float>> &a, int lda,                     \
2086       const DeviceMemory<std::complex<float>> &b, int ldb, float beta,         \
2087       DeviceMemory<std::complex<float>> *c, int ldc) override;                 \
2088   bool DoBlasHer2k(                                                            \
2089       Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
2090       uint64 k, std::complex<double> alpha,                                    \
2091       const DeviceMemory<std::complex<double>> &a, int lda,                    \
2092       const DeviceMemory<std::complex<double>> &b, int ldb, double beta,       \
2093       DeviceMemory<std::complex<double>> *c, int ldc) override;                \
2094   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2095                   uint64 m, uint64 n, float alpha,                             \
2096                   const DeviceMemory<float> &a, int lda,                       \
2097                   const DeviceMemory<float> &b, int ldb, float beta,           \
2098                   DeviceMemory<float> *c, int ldc) override;                   \
2099   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2100                   uint64 m, uint64 n, double alpha,                            \
2101                   const DeviceMemory<double> &a, int lda,                      \
2102                   const DeviceMemory<double> &b, int ldb, double beta,         \
2103                   DeviceMemory<double> *c, int ldc) override;                  \
2104   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2105                   uint64 m, uint64 n, std::complex<float> alpha,               \
2106                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2107                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2108                   std::complex<float> beta,                                    \
2109                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2110   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2111                   uint64 m, uint64 n, std::complex<double> alpha,              \
2112                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2113                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2114                   std::complex<double> beta,                                   \
2115                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2116   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2117                   blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
2118                   const DeviceMemory<float> &a, int lda, float beta,           \
2119                   DeviceMemory<float> *c, int ldc) override;                   \
2120   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2121                   blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
2122                   const DeviceMemory<double> &a, int lda, double beta,         \
2123                   DeviceMemory<double> *c, int ldc) override;                  \
2124   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2125                   blas::Transpose trans, uint64 n, uint64 k,                   \
2126                   std::complex<float> alpha,                                   \
2127                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2128                   std::complex<float> beta,                                    \
2129                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2130   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2131                   blas::Transpose trans, uint64 n, uint64 k,                   \
2132                   std::complex<double> alpha,                                  \
2133                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2134                   std::complex<double> beta,                                   \
2135                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2136   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2137                    blas::Transpose trans, uint64 n, uint64 k, float alpha,     \
2138                    const DeviceMemory<float> &a, int lda,                      \
2139                    const DeviceMemory<float> &b, int ldb, float beta,          \
2140                    DeviceMemory<float> *c, int ldc) override;                  \
2141   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2142                    blas::Transpose trans, uint64 n, uint64 k, double alpha,    \
2143                    const DeviceMemory<double> &a, int lda,                     \
2144                    const DeviceMemory<double> &b, int ldb, double beta,        \
2145                    DeviceMemory<double> *c, int ldc) override;                 \
2146   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2147                    blas::Transpose trans, uint64 n, uint64 k,                  \
2148                    std::complex<float> alpha,                                  \
2149                    const DeviceMemory<std::complex<float>> &a, int lda,        \
2150                    const DeviceMemory<std::complex<float>> &b, int ldb,        \
2151                    std::complex<float> beta,                                   \
2152                    DeviceMemory<std::complex<float>> *c, int ldc) override;    \
2153   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2154                    blas::Transpose trans, uint64 n, uint64 k,                  \
2155                    std::complex<double> alpha,                                 \
2156                    const DeviceMemory<std::complex<double>> &a, int lda,       \
2157                    const DeviceMemory<std::complex<double>> &b, int ldb,       \
2158                    std::complex<double> beta,                                  \
2159                    DeviceMemory<std::complex<double>> *c, int ldc) override;   \
2160   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2161                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2162                   uint64 n, float alpha, const DeviceMemory<float> &a,         \
2163                   int lda, DeviceMemory<float> *b, int ldb) override;          \
2164   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2165                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2166                   uint64 n, double alpha, const DeviceMemory<double> &a,       \
2167                   int lda, DeviceMemory<double> *b, int ldb) override;         \
2168   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2169                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2170                   uint64 n, std::complex<float> alpha,                         \
2171                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2172                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
2173   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2174                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2175                   uint64 n, std::complex<double> alpha,                        \
2176                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2177                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
2178   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2179                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2180                   uint64 n, float alpha, const DeviceMemory<float> &a,         \
2181                   int lda, DeviceMemory<float> *b, int ldb) override;          \
2182   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2183                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2184                   uint64 n, double alpha, const DeviceMemory<double> &a,       \
2185                   int lda, DeviceMemory<double> *b, int ldb) override;         \
2186   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2187                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2188                   uint64 n, std::complex<float> alpha,                         \
2189                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2190                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
2191   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2192                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2193                   uint64 n, std::complex<double> alpha,                        \
2194                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2195                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
2196   port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>                     \
2197   CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) override; \
2198   port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>   \
2199   GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,               \
2200                             size_t max_workspace_size,                         \
2201                             int max_algorithm_count) override;                 \
2202   bool DoBlasLtMatmul(                                                         \
2203       Stream *stream, const blas::IBlasLtMatmulPlan *plan,                     \
2204       const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,               \
2205       DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,                \
2206       DeviceMemoryBase c, ScratchAllocator *scratch_allocator,                 \
2207       const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,    \
2208       blas::ProfileResult *output_profile_result) override;                    \
2209   port::Status GetVersion(std::string *version) override;
2210 
2211 }  // namespace blas
2212 }  // namespace stream_executor
2213 
2214 #endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
2215