• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Exposes the family of BLAS routines as pre-canned high performance calls for
17 // use in conjunction with the StreamExecutor abstraction.
18 //
19 // Note that this interface is optionally supported by platforms; see
20 // StreamExecutor::SupportsBlas() for details.
21 //
22 // This abstraction makes it simple to entrain BLAS operations on GPU data into
23 // a Stream -- users typically will not use this API directly, but will use the
24 // Stream builder methods to entrain these operations "under the hood". For
25 // example:
26 //
27 //  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
28 //  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
29 //  // ... populate x and y ...
30 //  Stream stream{stream_exec};
31 //  stream
32 //    .Init()
33 //    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
34 //  SE_CHECK_OK(stream.BlockHostUntilDone());
35 //
36 // By using stream operations in this manner the user can easily intermix custom
37 // kernel launches (via Stream::ThenLaunch()) with these pre-canned BLAS
38 // routines.
39 
40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
42 
#include <complex>
#include <iosfwd>
#include <limits>
#include <vector>

#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
50 
// Forward declaration of Eigen's half-precision type, so this header can name
// Eigen::half without including any Eigen headers.
namespace Eigen { struct half; }  // namespace Eigen
54 
55 namespace stream_executor {
56 
// Forward declarations: lets this header refer to these types by pointer or
// reference without including their definitions.
class Stream;
class ScratchAllocator;

template <class ElemT>
class DeviceMemory;
62 
63 namespace blas {
64 
// Selects how an input matrix is presented to a BLAS operation: unchanged,
// transposed, or transposed and conjugated.
enum class Transpose {
  kNoTranspose,
  kTranspose,
  kConjugateTranspose,
};
68 
69 // Returns a name for t.
70 string TransposeString(Transpose t);
71 
// Selects which triangle (upper or lower) of a symmetric/Hermitian matrix a
// BLAS operation uses.
enum class UpperLower {
  kUpper,
  kLower,
};
75 
76 // Returns a name for ul.
77 string UpperLowerString(UpperLower ul);
78 
// Indicates whether a triangular matrix is unit triangular (kUnit) or not
// (kNonUnit).
enum class Diagonal {
  kUnit,
  kNonUnit,
};
81 
82 // Returns a name for d.
83 string DiagonalString(Diagonal d);
84 
// Indicates whether a Hermitian matrix appears on the left or on the right
// side of the operation.
enum class Side {
  kLeft,
  kRight,
};
88 
89 // Returns a name for s.
90 string SideString(Side s);
91 
// The type a BLAS routine uses for its intermediate computations.
//
// Some BLAS calls compute with a type that differs from the type of their
// inputs/outputs. For example, two matrices of int8s can be multiplied while
// float32s hold the matmul's intermediate values.
enum class ComputationType {
  // 16-bit floating point.
  kF16,
  // 32-bit floating point.
  kF32,
  // 64-bit floating point.
  kF64,
  // 32-bit integer.
  kI32,
  // Complex number made of two f32s.
  kComplexF32,
  // Complex number made of two f64s.
  kComplexF64,
};
105 
106 // Converts a ComputationType to a string.
107 string ComputationTypeString(ComputationType ty);
108 
109 std::ostream &operator<<(std::ostream &os, ComputationType ty);
110 
111 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
112 // as a hint to the blas library.
113 typedef int64 AlgorithmType;
114 constexpr AlgorithmType kDefaultAlgorithm = -1;
115 constexpr AlgorithmType kDefaultBlasGemm = -2;
116 constexpr AlgorithmType kDefaultBlasGemv = -3;
117 constexpr AlgorithmType kNoAlgorithm = -4;
118 
119 // blas uses -1 to represent the default algorithm. This happens to match up
120 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
121 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
122 // to ensure that this assumption does not break.
123 // If another blas implementation uses a different value for the default
124 // algorithm, then it needs to convert kDefaultGemmAlgo to that value
125 // (e.g. via a function called ToWhateverGemmAlgo).
126 constexpr AlgorithmType kDefaultGemmAlgo = -1;
127 
128 // Describes the result of a performance experiment, usually timing the speed of
129 // a particular AlgorithmType.
130 //
131 // If the call we were benchmarking failed (a common occurrence; not all
132 // algorithms are valid for all calls), is_valid() will be false.
133 class ProfileResult {
134  public:
is_valid()135   bool is_valid() const { return is_valid_; }
set_is_valid(bool val)136   void set_is_valid(bool val) { is_valid_ = val; }
algorithm()137   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)138   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
elapsed_time_in_ms()139   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
set_elapsed_time_in_ms(float val)140   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
141 
142  private:
143   bool is_valid_ = false;
144   AlgorithmType algorithm_ = kDefaultAlgorithm;
145   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
146 };
147 
148 class AlgorithmConfig {
149  public:
AlgorithmConfig()150   AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
AlgorithmConfig(AlgorithmType algorithm)151   explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
algorithm()152   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)153   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
154   bool operator==(const AlgorithmConfig &other) const {
155     return this->algorithm_ == other.algorithm_;
156   }
157   bool operator!=(const AlgorithmConfig &other) const {
158     return !(*this == other);
159   }
160   string ToString() const;
161 
162  private:
163   AlgorithmType algorithm_;
164 };
165 
166 // BLAS support interface -- this can be derived from a GPU executor when the
167 // underlying platform has a BLAS library implementation available. See
168 // StreamExecutor::AsBlas().
169 //
170 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in
171 // the system. Any operation that a user attempts to perform by enqueueing BLAS
172 // operations on a thread not-associated with the CUDA-context has unknown
173 // behavior at the current time; see b/13176597
174 class BlasSupport {
175  public:
~BlasSupport()176   virtual ~BlasSupport() {}
177 
178   // Computes the sum of magnitudes of the vector elements.
179   // result <- |Re x(1)| + |Im x(1)| + |Re  x(2)| + |Im  x(2)|+ ... + |Re  x(n)|
180   // + |Im x(n)|.
181   // Note that Im x(i) = 0 for real types float/double.
182   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
183                           const DeviceMemory<float> &x, int incx,
184                           DeviceMemory<float> *result) = 0;
185   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
186                           const DeviceMemory<double> &x, int incx,
187                           DeviceMemory<double> *result) = 0;
188   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
189                           const DeviceMemory<std::complex<float>> &x, int incx,
190                           DeviceMemory<float> *result) = 0;
191   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
192                           const DeviceMemory<std::complex<double>> &x, int incx,
193                           DeviceMemory<double> *result) = 0;
194 
195   // Performs a BLAS y <- ax+y operation.
196   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
197                           const DeviceMemory<float> &x, int incx,
198                           DeviceMemory<float> *y, int incy) = 0;
199   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
200                           const DeviceMemory<double> &x, int incx,
201                           DeviceMemory<double> *y, int incy) = 0;
202   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
203                           std::complex<float> alpha,
204                           const DeviceMemory<std::complex<float>> &x, int incx,
205                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
206   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
207                           std::complex<double> alpha,
208                           const DeviceMemory<std::complex<double>> &x, int incx,
209                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
210 
211   // Copies vector to another vector: y <- x.
212   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
213                           const DeviceMemory<float> &x, int incx,
214                           DeviceMemory<float> *y, int incy) = 0;
215   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
216                           const DeviceMemory<double> &x, int incx,
217                           DeviceMemory<double> *y, int incy) = 0;
218   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
219                           const DeviceMemory<std::complex<float>> &x, int incx,
220                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
221   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
222                           const DeviceMemory<std::complex<double>> &x, int incx,
223                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
224 
225   // Performs a BLAS dot product result <- x . y.
226   virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
227                          const DeviceMemory<float> &x, int incx,
228                          const DeviceMemory<float> &y, int incy,
229                          DeviceMemory<float> *result) = 0;
230   virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
231                          const DeviceMemory<double> &x, int incx,
232                          const DeviceMemory<double> &y, int incy,
233                          DeviceMemory<double> *result) = 0;
234 
235   // Performs a BLAS dot product result <- conj(x) . y for complex types.
236   virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
237                           const DeviceMemory<std::complex<float>> &x, int incx,
238                           const DeviceMemory<std::complex<float>> &y, int incy,
239                           DeviceMemory<std::complex<float>> *result) = 0;
240   virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
241                           const DeviceMemory<std::complex<double>> &x, int incx,
242                           const DeviceMemory<std::complex<double>> &y, int incy,
243                           DeviceMemory<std::complex<double>> *result) = 0;
244 
245   // Performs a BLAS dot product result <- x . y for complex types. Note that
246   // x is unconjugated in this routine.
247   virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
248                           const DeviceMemory<std::complex<float>> &x, int incx,
249                           const DeviceMemory<std::complex<float>> &y, int incy,
250                           DeviceMemory<std::complex<float>> *result) = 0;
251   virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
252                           const DeviceMemory<std::complex<double>> &x, int incx,
253                           const DeviceMemory<std::complex<double>> &y, int incy,
254                           DeviceMemory<std::complex<double>> *result) = 0;
255 
256   // Computes the Euclidean norm of a vector: result <- ||x||.
257   // See the following link for more information on the Euclidean norm:
258   // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
259   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
260                           const DeviceMemory<float> &x, int incx,
261                           DeviceMemory<float> *result) = 0;
262   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
263                           const DeviceMemory<double> &x, int incx,
264                           DeviceMemory<double> *result) = 0;
265   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
266                           const DeviceMemory<std::complex<float>> &x, int incx,
267                           DeviceMemory<float> *result) = 0;
268   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
269                           const DeviceMemory<std::complex<double>> &x, int incx,
270                           DeviceMemory<double> *result) = 0;
271 
272   // Performs rotation of points in the plane:
273   // x(i) = c*x(i) + s*y(i)
274   // y(i) = c*y(i) - s*x(i).
275   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
276                          DeviceMemory<float> *x, int incx,
277                          DeviceMemory<float> *y, int incy, float c,
278                          float s) = 0;
279   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
280                          DeviceMemory<double> *x, int incx,
281                          DeviceMemory<double> *y, int incy, double c,
282                          double s) = 0;
283   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
284                          DeviceMemory<std::complex<float>> *x, int incx,
285                          DeviceMemory<std::complex<float>> *y, int incy,
286                          float c, float s) = 0;
287   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
288                          DeviceMemory<std::complex<double>> *x, int incx,
289                          DeviceMemory<std::complex<double>> *y, int incy,
290                          double c, double s) = 0;
291 
292   // Computes the parameters for a Givens rotation.
293   // Given the Cartesian coordinates (a, b) of a point, these routines return
294   // the parameters c, s, r, and z associated with the Givens rotation. The
295   // parameters c and s define a unitary matrix such that:
296   //
297   //   |  c s |.| a | = | r |
298   //   | -s c | | b |   | 0 |
299   //
300   // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
301   // c is not 0 z is 1/c; otherwise z is 1.
302   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
303                           DeviceMemory<float> *b, DeviceMemory<float> *c,
304                           DeviceMemory<float> *s) = 0;
305   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
306                           DeviceMemory<double> *b, DeviceMemory<double> *c,
307                           DeviceMemory<double> *s) = 0;
308   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
309                           DeviceMemory<std::complex<float>> *b,
310                           DeviceMemory<float> *c,
311                           DeviceMemory<std::complex<float>> *s) = 0;
312   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
313                           DeviceMemory<std::complex<double>> *b,
314                           DeviceMemory<double> *c,
315                           DeviceMemory<std::complex<double>> *s) = 0;
316 
317   // Performs modified Givens rotation of points in the plane.
318   // Given two vectors x and y, each vector element of these vectors is replaced
319   // as follows:
320   //
321   //   | x(i) | =  H | x(i) |
322   //   | y(i) |      | y(i) |
323   //
324   // for i=1 to n, where H is a modified Givens transformation matrix whose
325   // values are stored in the param[1] through param[4] array.
326   // See the reference BLAS documentation of srotm/drotm for details.
327   virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
328                           DeviceMemory<float> *x, int incx,
329                           DeviceMemory<float> *y, int incy,
330                           const DeviceMemory<float> &param) = 0;
331   virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
332                           DeviceMemory<double> *x, int incx,
333                           DeviceMemory<double> *y, int incy,
334                           const DeviceMemory<double> &param) = 0;
335 
336   // Computes the parameters for a modified Givens rotation.
337   // Given Cartesian coordinates (x1, y1) of an input vector, these routines
338   // compute the components of a modified Givens transformation matrix H that
339   // zeros the y-component of the resulting vector:
340   //
341   //   | x1 | =  H | x1 * sqrt(d1) |
342   //   |  0 |      | y1 * sqrt(d2) |
343   //
344   // See the reference BLAS documentation of srotmg/drotmg for details.
345   virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
346                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
347                            const DeviceMemory<float> &y1,
348                            DeviceMemory<float> *param) = 0;
349   virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
350                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
351                            const DeviceMemory<double> &y1,
352                            DeviceMemory<double> *param) = 0;
353 
354   // Computes the product of a vector by a scalar: x <- a*x.
355   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
356                           DeviceMemory<float> *x, int incx) = 0;
357   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
358                           DeviceMemory<double> *x, int incx) = 0;
359   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
360                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
361   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
362                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
363   virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
364                           std::complex<float> alpha,
365                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
366   virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
367                           std::complex<double> alpha,
368                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
369 
370   // Swaps a vector with another vector.
371   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
372                           DeviceMemory<float> *x, int incx,
373                           DeviceMemory<float> *y, int incy) = 0;
374   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
375                           DeviceMemory<double> *x, int incx,
376                           DeviceMemory<double> *y, int incy) = 0;
377   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
378                           DeviceMemory<std::complex<float>> *x, int incx,
379                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
380   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
381                           DeviceMemory<std::complex<double>> *x, int incx,
382                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
383 
384   // Finds the index of the element with maximum absolute value.
385   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
386                            const DeviceMemory<float> &x, int incx,
387                            DeviceMemory<int> *result) = 0;
388   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
389                            const DeviceMemory<double> &x, int incx,
390                            DeviceMemory<int> *result) = 0;
391   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
392                            const DeviceMemory<std::complex<float>> &x, int incx,
393                            DeviceMemory<int> *result) = 0;
394   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
395                            const DeviceMemory<std::complex<double>> &x,
396                            int incx, DeviceMemory<int> *result) = 0;
397 
398   // Finds the index of the element with minimum absolute value.
399   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
400                            const DeviceMemory<float> &x, int incx,
401                            DeviceMemory<int> *result) = 0;
402   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
403                            const DeviceMemory<double> &x, int incx,
404                            DeviceMemory<int> *result) = 0;
405   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
406                            const DeviceMemory<std::complex<float>> &x, int incx,
407                            DeviceMemory<int> *result) = 0;
408   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
409                            const DeviceMemory<std::complex<double>> &x,
410                            int incx, DeviceMemory<int> *result) = 0;
411 
412   // Computes a matrix-vector product using a general band matrix:
413   //
414   //     y <- alpha * a * x + beta * y,
415   // or
416   //     y <- alpha * a' * x + beta * y,
417   // or
418   //     y <- alpha * conj(a') * x + beta * y,
419   //
420   // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
421   // sub-diagonals and ku super-diagonals; x is a vector with
422   // n(trans==kNoTranspose)/m(otherwise) elements;
423   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
424   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
425                           uint64 n, uint64 kl, uint64 ku, float alpha,
426                           const DeviceMemory<float> &a, int lda,
427                           const DeviceMemory<float> &x, int incx, float beta,
428                           DeviceMemory<float> *y, int incy) = 0;
429   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
430                           uint64 n, uint64 kl, uint64 ku, double alpha,
431                           const DeviceMemory<double> &a, int lda,
432                           const DeviceMemory<double> &x, int incx, double beta,
433                           DeviceMemory<double> *y, int incy) = 0;
434   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
435                           uint64 n, uint64 kl, uint64 ku,
436                           std::complex<float> alpha,
437                           const DeviceMemory<std::complex<float>> &a, int lda,
438                           const DeviceMemory<std::complex<float>> &x, int incx,
439                           std::complex<float> beta,
440                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
441   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
442                           uint64 n, uint64 kl, uint64 ku,
443                           std::complex<double> alpha,
444                           const DeviceMemory<std::complex<double>> &a, int lda,
445                           const DeviceMemory<std::complex<double>> &x, int incx,
446                           std::complex<double> beta,
447                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
448 
449   // Computes a matrix-vector product using a general matrix.
450   //
451   //     y <- alpha * a * x + beta * y,
452   // or
453   //     y <- alpha * a' * x + beta * y,
454   // or
455   //     y <- alpha * conj(a') * x + beta * y,
456   //
457   // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
458   // with n(trans==kNoTranspose)/m(otherwise) elements;
459   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
460   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
461                           uint64 n, float alpha, const DeviceMemory<float> &a,
462                           int lda, const DeviceMemory<float> &x, int incx,
463                           float beta, DeviceMemory<float> *y, int incy) = 0;
464   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
465                           uint64 n, double alpha, const DeviceMemory<double> &a,
466                           int lda, const DeviceMemory<double> &x, int incx,
467                           double beta, DeviceMemory<double> *y, int incy) = 0;
468   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
469                           uint64 n, std::complex<float> alpha,
470                           const DeviceMemory<std::complex<float>> &a, int lda,
471                           const DeviceMemory<std::complex<float>> &x, int incx,
472                           std::complex<float> beta,
473                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
474   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
475                           uint64 n, std::complex<double> alpha,
476                           const DeviceMemory<std::complex<double>> &a, int lda,
477                           const DeviceMemory<std::complex<double>> &x, int incx,
478                           std::complex<double> beta,
479                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
480 
481   virtual bool DoBlasGemvWithProfiling(
482       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
483       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
484       int incx, float beta, DeviceMemory<float> *y, int incy,
485       ProfileResult *output_profile_result) = 0;
486   virtual bool DoBlasGemvWithProfiling(
487       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
488       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
489       int incx, double beta, DeviceMemory<double> *y, int incy,
490       ProfileResult *output_profile_result) = 0;
491   virtual bool DoBlasGemvWithProfiling(
492       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
493       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
494       int lda, const DeviceMemory<std::complex<float>> &x, int incx,
495       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
496       ProfileResult *output_profile_result) = 0;
497   virtual bool DoBlasGemvWithProfiling(
498       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
499       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
500       int lda, const DeviceMemory<std::complex<double>> &x, int incx,
501       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
502       int incy, ProfileResult *output_profile_result) = 0;
503 
504   // Performs a rank-1 update of a general matrix.
505   //
506   //     a <- alpha * x * y' + a,
507   //
508   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
509   // an m-by-n general matrix.
510   virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
511                          const DeviceMemory<float> &x, int incx,
512                          const DeviceMemory<float> &y, int incy,
513                          DeviceMemory<float> *a, int lda) = 0;
514   virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
515                          const DeviceMemory<double> &x, int incx,
516                          const DeviceMemory<double> &y, int incy,
517                          DeviceMemory<double> *a, int lda) = 0;
518 
519   // Performs a rank-1 update (conjugated) of a general matrix.
520   //
521   //     a <- alpha * x * conj(y') + a,
522   //
523   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
524   // an m-by-n general matrix.
525   virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
526                           std::complex<float> alpha,
527                           const DeviceMemory<std::complex<float>> &x, int incx,
528                           const DeviceMemory<std::complex<float>> &y, int incy,
529                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
530   virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
531                           std::complex<double> alpha,
532                           const DeviceMemory<std::complex<double>> &x, int incx,
533                           const DeviceMemory<std::complex<double>> &y, int incy,
534                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
535 
536   // Performs a rank-1 update (unconjugated) of a general matrix.
537   //
538   //     a <- alpha * x * y' + a,
539   //
540   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
541   // an m-by-n general matrix.
542   virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
543                           std::complex<float> alpha,
544                           const DeviceMemory<std::complex<float>> &x, int incx,
545                           const DeviceMemory<std::complex<float>> &y, int incy,
546                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
547   virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
548                           std::complex<double> alpha,
549                           const DeviceMemory<std::complex<double>> &x, int incx,
550                           const DeviceMemory<std::complex<double>> &y, int incy,
551                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
552 
553   // Computes a matrix-vector product using a Hermitian band matrix.
554   //
555   //     y <- alpha * a * x + beta * y,
556   //
557   // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
558   // super-diagonals; x and y are n-element vectors.
559   virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
560                           uint64 k, std::complex<float> alpha,
561                           const DeviceMemory<std::complex<float>> &a, int lda,
562                           const DeviceMemory<std::complex<float>> &x, int incx,
563                           std::complex<float> beta,
564                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
565   virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
566                           uint64 k, std::complex<double> alpha,
567                           const DeviceMemory<std::complex<double>> &a, int lda,
568                           const DeviceMemory<std::complex<double>> &x, int incx,
569                           std::complex<double> beta,
570                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
571 
572   // Computes a matrix-vector product using a Hermitian matrix.
573   //
574   //     y <- alpha * a * x + beta * y,
575   //
576   // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
577   // n-element vectors.
  // uplo selects which triangle (upper or lower) of a is referenced.
578   virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
579                           std::complex<float> alpha,
580                           const DeviceMemory<std::complex<float>> &a, int lda,
581                           const DeviceMemory<std::complex<float>> &x, int incx,
582                           std::complex<float> beta,
583                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
584   virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
585                           std::complex<double> alpha,
586                           const DeviceMemory<std::complex<double>> &a, int lda,
587                           const DeviceMemory<std::complex<double>> &x, int incx,
588                           std::complex<double> beta,
589                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
590 
591   // Performs a rank-1 update of a Hermitian matrix.
592   //
593   //     a <- alpha * x * conj(x') + a,
594   //
595   // alpha is a real scalar; x is an n-element vector; a is an n-by-n Hermitian
596   // matrix.
  // uplo selects which triangle (upper or lower) of a is updated.
597   virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
598                          float alpha,
599                          const DeviceMemory<std::complex<float>> &x, int incx,
600                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
601   virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
602                          double alpha,
603                          const DeviceMemory<std::complex<double>> &x, int incx,
604                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
605 
606   // Performs a rank-2 update of a Hermitian matrix.
607   //
608   //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
609   //
610   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
611   // matrix.
  // uplo selects which triangle (upper or lower) of a is updated.
612   virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
613                           std::complex<float> alpha,
614                           const DeviceMemory<std::complex<float>> &x, int incx,
615                           const DeviceMemory<std::complex<float>> &y, int incy,
616                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
617   virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
618                           std::complex<double> alpha,
619                           const DeviceMemory<std::complex<double>> &x, int incx,
620                           const DeviceMemory<std::complex<double>> &y, int incy,
621                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
622 
623   // Computes a matrix-vector product using a Hermitian packed matrix.
624   //
625   //     y <- alpha * a * x + beta * y,
626   //
627   // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
628   // packed form as ap; x and y are n-element vectors.
  // uplo selects which triangle (upper or lower) of a is packed into ap.
629   virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
630                           std::complex<float> alpha,
631                           const DeviceMemory<std::complex<float>> &ap,
632                           const DeviceMemory<std::complex<float>> &x, int incx,
633                           std::complex<float> beta,
634                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
635   virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
636                           std::complex<double> alpha,
637                           const DeviceMemory<std::complex<double>> &ap,
638                           const DeviceMemory<std::complex<double>> &x, int incx,
639                           std::complex<double> beta,
640                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
641 
642   // Performs a rank-1 update of a Hermitian packed matrix.
643   //
644   //     a <- alpha * x * conj(x') + a,
645   //
646   // alpha is a real scalar; x is an n-element vector; a is an n-by-n Hermitian
647   // matrix, supplied in packed form as ap.
  // uplo selects which triangle (upper or lower) of a is packed into ap.
648   virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
649                          float alpha,
650                          const DeviceMemory<std::complex<float>> &x, int incx,
651                          DeviceMemory<std::complex<float>> *ap) = 0;
652   virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
653                          double alpha,
654                          const DeviceMemory<std::complex<double>> &x, int incx,
655                          DeviceMemory<std::complex<double>> *ap) = 0;
656 
657   // Performs a rank-2 update of a Hermitian packed matrix.
658   //
659   //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
660   //
661   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
662   // matrix, supplied in packed form as ap.
  // uplo selects which triangle (upper or lower) of a is packed into ap.
663   virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
664                           std::complex<float> alpha,
665                           const DeviceMemory<std::complex<float>> &x, int incx,
666                           const DeviceMemory<std::complex<float>> &y, int incy,
667                           DeviceMemory<std::complex<float>> *ap) = 0;
668   virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
669                           std::complex<double> alpha,
670                           const DeviceMemory<std::complex<double>> &x, int incx,
671                           const DeviceMemory<std::complex<double>> &y, int incy,
672                           DeviceMemory<std::complex<double>> *ap) = 0;
673 
674   // Computes a matrix-vector product using a symmetric band matrix.
675   //
676   //     y <- alpha * a * x + beta * y,
677   //
678   // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
679   // super-diagonals; x and y are n-element vectors.
  // uplo selects which triangle (upper or lower) of a is stored in band form.
680   virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
681                           uint64 k, float alpha, const DeviceMemory<float> &a,
682                           int lda, const DeviceMemory<float> &x, int incx,
683                           float beta, DeviceMemory<float> *y, int incy) = 0;
684   virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
685                           uint64 k, double alpha, const DeviceMemory<double> &a,
686                           int lda, const DeviceMemory<double> &x, int incx,
687                           double beta, DeviceMemory<double> *y, int incy) = 0;
688 
689   // Computes a matrix-vector product using a symmetric packed matrix.
690   //
691   //     y <- alpha * a * x + beta * y,
692   //
693   // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
694   // packed form as ap; x and y are n-element vectors.
  // uplo selects which triangle (upper or lower) of a is packed into ap.
695   virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
696                           float alpha, const DeviceMemory<float> &ap,
697                           const DeviceMemory<float> &x, int incx, float beta,
698                           DeviceMemory<float> *y, int incy) = 0;
699   virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
700                           double alpha, const DeviceMemory<double> &ap,
701                           const DeviceMemory<double> &x, int incx, double beta,
702                           DeviceMemory<double> *y, int incy) = 0;
703 
704   // Performs a rank-1 update of a symmetric packed matrix.
705   //
706   //     a <- alpha * x * x' + a,
707   //
708   // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
709   // matrix, supplied in packed form as ap.
  // uplo selects which triangle (upper or lower) of a is packed into ap.
710   virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
711                          float alpha, const DeviceMemory<float> &x, int incx,
712                          DeviceMemory<float> *ap) = 0;
713   virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
714                          double alpha, const DeviceMemory<double> &x, int incx,
715                          DeviceMemory<double> *ap) = 0;
716 
717   // Performs a rank-2 update of a symmetric packed matrix.
718   //
719   //     a <- alpha * x * y' + alpha * y * x' + a,
720   //
721   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
722   // matrix, supplied in packed form as ap.
  // uplo selects which triangle (upper or lower) of a is packed into ap.
723   virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
724                           float alpha, const DeviceMemory<float> &x, int incx,
725                           const DeviceMemory<float> &y, int incy,
726                           DeviceMemory<float> *ap) = 0;
727   virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
728                           double alpha, const DeviceMemory<double> &x, int incx,
729                           const DeviceMemory<double> &y, int incy,
730                           DeviceMemory<double> *ap) = 0;
731 
732   // Computes a matrix-vector product for a symmetric matrix.
733   //
734   //     y <- alpha * a * x + beta * y,
735   //
736   // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
737   // n-element vectors.
  // uplo selects which triangle (upper or lower) of a is referenced.
738   virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
739                           float alpha, const DeviceMemory<float> &a, int lda,
740                           const DeviceMemory<float> &x, int incx, float beta,
741                           DeviceMemory<float> *y, int incy) = 0;
742   virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
743                           double alpha, const DeviceMemory<double> &a, int lda,
744                           const DeviceMemory<double> &x, int incx, double beta,
745                           DeviceMemory<double> *y, int incy) = 0;
746 
747   // Performs a rank-1 update of a symmetric matrix.
748   //
749   //     a <- alpha * x * x' + a,
750   //
751   // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
752   // matrix.
  // uplo selects which triangle (upper or lower) of a is updated.
753   virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
754                          float alpha, const DeviceMemory<float> &x, int incx,
755                          DeviceMemory<float> *a, int lda) = 0;
756   virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
757                          double alpha, const DeviceMemory<double> &x, int incx,
758                          DeviceMemory<double> *a, int lda) = 0;
759 
760   // Performs a rank-2 update of a symmetric matrix.
761   //
762   //     a <- alpha * x * y' + alpha * y * x' + a,
763   //
764   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
765   // matrix.
  // uplo selects which triangle (upper or lower) of a is updated.
766   virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
767                           float alpha, const DeviceMemory<float> &x, int incx,
768                           const DeviceMemory<float> &y, int incy,
769                           DeviceMemory<float> *a, int lda) = 0;
770   virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
771                           double alpha, const DeviceMemory<double> &x, int incx,
772                           const DeviceMemory<double> &y, int incy,
773                           DeviceMemory<double> *a, int lda) = 0;
774 
775   // Computes a matrix-vector product using a triangular band matrix.
776   //
777   //     x <- a * x,
778   // or
779   //     x <- a' * x,
780   // or
781   //     x <- conj(a') * x,
782   //
783   // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
784   // with k+1 diagonals; x is an n-element vector.
  // trans selects which form above is computed; diag selects unit or non-unit.
785   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
786                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
787                           uint64 k, const DeviceMemory<float> &a, int lda,
788                           DeviceMemory<float> *x, int incx) = 0;
789   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
790                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
791                           uint64 k, const DeviceMemory<double> &a, int lda,
792                           DeviceMemory<double> *x, int incx) = 0;
793   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
794                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
795                           uint64 k, const DeviceMemory<std::complex<float>> &a,
796                           int lda, DeviceMemory<std::complex<float>> *x,
797                           int incx) = 0;
798   virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
799                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
800                           uint64 k, const DeviceMemory<std::complex<double>> &a,
801                           int lda, DeviceMemory<std::complex<double>> *x,
802                           int incx) = 0;
803 
804   // Solves a system of linear equations whose coefficients are in a triangular
805   // band matrix as below:
806   //
807   //     a * x = b,
808   // or
809   //     a' * x = b,
810   // or
811   //     conj(a') * x = b,
812   //
813   // b and x are n-element vectors (x holds b on entry and is overwritten with
814   // the solution); a is an n-by-n unit, or non-unit, upper or lower triangular
  // band matrix, with k+1 diagonals.
  // trans selects which system above is solved; diag selects unit or non-unit.
815   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
816                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
817                           uint64 k, const DeviceMemory<float> &a, int lda,
818                           DeviceMemory<float> *x, int incx) = 0;
819   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
820                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
821                           uint64 k, const DeviceMemory<double> &a, int lda,
822                           DeviceMemory<double> *x, int incx) = 0;
823   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
824                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
825                           uint64 k, const DeviceMemory<std::complex<float>> &a,
826                           int lda, DeviceMemory<std::complex<float>> *x,
827                           int incx) = 0;
828   virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
829                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
830                           uint64 k, const DeviceMemory<std::complex<double>> &a,
831                           int lda, DeviceMemory<std::complex<double>> *x,
832                           int incx) = 0;
833 
834   // Computes a matrix-vector product using a triangular packed matrix.
835   //
836   //     x <- a * x,
837   // or
838   //     x <- a' * x,
839   // or
840   //     x <- conj(a') * x,
841   //
842   // a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
843   // supplied in packed form as ap; x is an n-element vector.
  // trans selects which form above is computed; diag selects unit or non-unit.
844   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
845                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
846                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
847                           int incx) = 0;
848   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
849                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
850                           const DeviceMemory<double> &ap,
851                           DeviceMemory<double> *x, int incx) = 0;
852   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
853                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
854                           const DeviceMemory<std::complex<float>> &ap,
855                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
856   virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
857                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
858                           const DeviceMemory<std::complex<double>> &ap,
859                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
860 
861   // Solves a system of linear equations whose coefficients are in a triangular
862   // packed matrix as below:
863   //
864   //     a * x = b,
865   // or
866   //     a' * x = b,
867   // or
868   //     conj(a') * x = b,
869   //
870   // b and x are n-element vectors (x holds b on entry and is overwritten with
871   // the solution); a is an n-by-n unit, or non-unit, upper or lower triangular
  // matrix, supplied in packed form as ap.
  // trans selects which system above is solved; diag selects unit or non-unit.
872   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
873                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
874                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
875                           int incx) = 0;
876   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
877                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
878                           const DeviceMemory<double> &ap,
879                           DeviceMemory<double> *x, int incx) = 0;
880   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
881                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
882                           const DeviceMemory<std::complex<float>> &ap,
883                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
884   virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
885                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
886                           const DeviceMemory<std::complex<double>> &ap,
887                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
888 
889   // Computes a matrix-vector product using a triangular matrix.
890   //
891   //     x <- a * x,
892   // or
893   //     x <- a' * x,
894   // or
895   //     x <- conj(a') * x,
896   //
897   // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is
898   // an n-element vector.
  // trans selects which form above is computed; diag selects unit or non-unit.
899   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
900                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
901                           const DeviceMemory<float> &a, int lda,
902                           DeviceMemory<float> *x, int incx) = 0;
903   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
904                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
905                           const DeviceMemory<double> &a, int lda,
906                           DeviceMemory<double> *x, int incx) = 0;
907   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
908                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
909                           const DeviceMemory<std::complex<float>> &a, int lda,
910                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
911   virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
912                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
913                           const DeviceMemory<std::complex<double>> &a, int lda,
914                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
915 
916   // Solves a system of linear equations whose coefficients are in a triangular
917   // matrix as below:
918   //
919   //     a * x = b,
920   // or
921   //     a' * x = b,
922   // or
923   //     conj(a') * x = b,
924   //
925   // b and x are n-element vectors (x holds b on entry and is overwritten with
926   // the solution); a is an n-by-n unit, or non-unit, upper or lower triangular
  // matrix.
  // trans selects which system above is solved; diag selects unit or non-unit.
927   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
928                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
929                           const DeviceMemory<float> &a, int lda,
930                           DeviceMemory<float> *x, int incx) = 0;
931   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
932                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
933                           const DeviceMemory<double> &a, int lda,
934                           DeviceMemory<double> *x, int incx) = 0;
935   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
936                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
937                           const DeviceMemory<std::complex<float>> &a, int lda,
938                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
939   virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
940                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
941                           const DeviceMemory<std::complex<double>> &a, int lda,
942                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
943 
944   // Computes a matrix-matrix product with general matrices:
945   //
946   //     c <- alpha * op(a) * op(b) + beta * c,
947   //
948   // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
949   // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
950   // op(b) is a k-by-n matrix; c is an m-by-n matrix.
951   //
952   // Note: The half overload uses float precision internally (alpha and beta
953   // are float); a version computing in half precision is not yet supported.
954   // DoBlasGemmBatched and DoBlasGemmStridedBatched also have half overloads.
955   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
956                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
957                           float alpha, const DeviceMemory<Eigen::half> &a,
958                           int lda, const DeviceMemory<Eigen::half> &b, int ldb,
959                           float beta, DeviceMemory<Eigen::half> *c,
960                           int ldc) = 0;
961   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
962                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
963                           float alpha, const DeviceMemory<float> &a, int lda,
964                           const DeviceMemory<float> &b, int ldb, float beta,
965                           DeviceMemory<float> *c, int ldc) = 0;
966   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
967                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
968                           double alpha, const DeviceMemory<double> &a, int lda,
969                           const DeviceMemory<double> &b, int ldb, double beta,
970                           DeviceMemory<double> *c, int ldc) = 0;
971   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
972                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
973                           std::complex<float> alpha,
974                           const DeviceMemory<std::complex<float>> &a, int lda,
975                           const DeviceMemory<std::complex<float>> &b, int ldb,
976                           std::complex<float> beta,
977                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
978   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
979                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
980                           std::complex<double> alpha,
981                           const DeviceMemory<std::complex<double>> &a, int lda,
982                           const DeviceMemory<std::complex<double>> &b, int ldb,
983                           std::complex<double> beta,
984                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
985 
  // Variants of DoBlasGemm that additionally take output_profile_result.
  // NOTE(review): presumably it receives timing information for the call --
  // confirm against the platform implementations.
986   virtual bool DoBlasGemmWithProfiling(
987       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
988       uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
989       int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
990       DeviceMemory<Eigen::half> *c, int ldc,
991       ProfileResult *output_profile_result) = 0;
992   virtual bool DoBlasGemmWithProfiling(
993       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
994       uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
995       const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
996       int ldc, ProfileResult *output_profile_result) = 0;
997   virtual bool DoBlasGemmWithProfiling(
998       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
999       uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1000       const DeviceMemory<double> &b, int ldb, double beta,
1001       DeviceMemory<double> *c, int ldc,
1002       ProfileResult *output_profile_result) = 0;
1003   virtual bool DoBlasGemmWithProfiling(
1004       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1005       uint64 n, uint64 k, std::complex<float> alpha,
1006       const DeviceMemory<std::complex<float>> &a, int lda,
1007       const DeviceMemory<std::complex<float>> &b, int ldb,
1008       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1009       ProfileResult *output_profile_result) = 0;
1010   virtual bool DoBlasGemmWithProfiling(
1011       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1012       uint64 n, uint64 k, std::complex<double> alpha,
1013       const DeviceMemory<std::complex<double>> &a, int lda,
1014       const DeviceMemory<std::complex<double>> &b, int ldb,
1015       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1016       ProfileResult *output_profile_result) = 0;
1011       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1012       uint64 n, uint64 k, std::complex<double> alpha,
1013       const DeviceMemory<std::complex<double>> &a, int lda,
1014       const DeviceMemory<std::complex<double>> &b, int ldb,
1015       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1016       ProfileResult *output_profile_result) = 0;
1017 
1018   // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
1019   virtual bool GetBlasGemmAlgorithms(
1020       std::vector<AlgorithmType> *out_algorithms) = 0;
1021 
1022   // Like DoBlasGemm, but accepts an algorithm and an compute type.
1023   //
1024   // The compute type lets you say (e.g.) that the inputs and outputs are
1025   // Eigen::halfs, but you want the internal computations to be done with
1026   // float32 precision.
1027   //
1028   // Note the subtle difference in the version that accepts Eigen:::half --
1029   // alpha and beta have type const Eigen::half&, not float.
1030   //
1031   // If output_profile_result is not null, a failure here does not put the
1032   // stream in a failure state.  Instead, success/failure is indicated by
1033   // output_profile_result->is_valid().  This lets you use this function for
1034   // choosing the best algorithm among many (some of which may fail) without
1035   // creating a new Stream for each attempt.
1036   virtual bool DoBlasGemmWithAlgorithm(
1037       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1038       uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
1039       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
1040       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
1041       int ldc, ComputationType computation_type, AlgorithmType algorithm,
1042       ProfileResult *output_profile_result) = 0;
1043   virtual bool DoBlasGemmWithAlgorithm(
1044       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1045       uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1046       const DeviceMemory<Eigen::half> &a, int lda,
1047       const DeviceMemory<Eigen::half> &b, int ldb,
1048       const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1049       int ldc, ComputationType computation_type, AlgorithmType algorithm,
1050       ProfileResult *output_profile_result) = 0;
1051   virtual bool DoBlasGemmWithAlgorithm(
1052       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1053       uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
1054       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1055       int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1056       int ldc, ComputationType computation_type, AlgorithmType algorithm,
1057       ProfileResult *output_profile_result) = 0;
1058   virtual bool DoBlasGemmWithAlgorithm(
1059       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1060       uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
1061       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1062       int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1063       int ldc, ComputationType computation_type, AlgorithmType algorithm,
1064       ProfileResult *output_profile_result) = 0;
1065   virtual bool DoBlasGemmWithAlgorithm(
1066       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1067       uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1068       const DeviceMemory<std::complex<float>> &a, int lda,
1069       const DeviceMemory<std::complex<float>> &b, int ldb,
1070       const HostOrDeviceScalar<std::complex<float>> &beta,
1071       DeviceMemory<std::complex<float>> *c, int ldc,
1072       ComputationType computation_type, AlgorithmType algorithm,
1073       ProfileResult *output_profile_result) = 0;
1074   virtual bool DoBlasGemmWithAlgorithm(
1075       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1076       uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1077       const DeviceMemory<std::complex<double>> &a, int lda,
1078       const DeviceMemory<std::complex<double>> &b, int ldb,
1079       const HostOrDeviceScalar<std::complex<double>> &beta,
1080       DeviceMemory<std::complex<double>> *c, int ldc,
1081       ComputationType computation_type, AlgorithmType algorithm,
1082       ProfileResult *output_profile_result) = 0;
1083 
1084   // Computes a batch of matrix-matrix product with general matrices.
1085   // This is a batched version of DoBlasGemm.
1086   // The batched GEMM computes matrix product for each input/output in a, b,
1087   // and c, which contain batch_count DeviceMemory objects.
1088   virtual bool DoBlasGemmBatched(
1089       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1090       uint64 n, uint64 k, float alpha,
1091       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1092       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1093       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1094       int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
1095   virtual bool DoBlasGemmBatched(
1096       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1097       uint64 n, uint64 k, float alpha,
1098       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
1099       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
1100       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
1101       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1102   virtual bool DoBlasGemmBatched(
1103       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1104       uint64 n, uint64 k, double alpha,
1105       const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
1106       const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
1107       const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
1108       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1109   virtual bool DoBlasGemmBatched(
1110       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1111       uint64 n, uint64 k, std::complex<float> alpha,
1112       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1113       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1114       std::complex<float> beta,
1115       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1116       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1117   virtual bool DoBlasGemmBatched(
1118       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1119       uint64 n, uint64 k, std::complex<double> alpha,
1120       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1121       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1122       std::complex<double> beta,
1123       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1124       int batch_count, ScratchAllocator *scratch_allocator) = 0;
1125 
1126   // Batched gemm with strides instead of pointer arrays.
1127   virtual bool DoBlasGemmStridedBatched(
1128       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1129       uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1130       int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
1131       int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
1132       int64 stride_c, int batch_count) = 0;
1133   virtual bool DoBlasGemmStridedBatched(
1134       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1135       uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1136       int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
1137       float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
1138       int batch_count) = 0;
1139   virtual bool DoBlasGemmStridedBatched(
1140       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1141       uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1142       int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
1143       double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
1144       int batch_count) = 0;
1145   virtual bool DoBlasGemmStridedBatched(
1146       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1147       uint64 n, uint64 k, std::complex<float> alpha,
1148       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
1149       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
1150       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1151       int64 stride_c, int batch_count) = 0;
1152   virtual bool DoBlasGemmStridedBatched(
1153       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1154       uint64 n, uint64 k, std::complex<double> alpha,
1155       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
1156       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
1157       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1158       int64 stride_c, int batch_count) = 0;
1159 
1160   // Computes a matrix-matrix product where one input matrix is Hermitian:
1161   //
1162   //     c <- alpha * a * b + beta * c,
1163   // or
1164   //     c <- alpha * b * a + beta * c,
1165   //
1166   // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
1167   // matrices.
1168   virtual bool DoBlasHemm(Stream *stream, blas::Side side,
1169                           blas::UpperLower uplo, uint64 m, uint64 n,
1170                           std::complex<float> alpha,
1171                           const DeviceMemory<std::complex<float>> &a, int lda,
1172                           const DeviceMemory<std::complex<float>> &b, int ldb,
1173                           std::complex<float> beta,
1174                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1175   virtual bool DoBlasHemm(Stream *stream, blas::Side side,
1176                           blas::UpperLower uplo, uint64 m, uint64 n,
1177                           std::complex<double> alpha,
1178                           const DeviceMemory<std::complex<double>> &a, int lda,
1179                           const DeviceMemory<std::complex<double>> &b, int ldb,
1180                           std::complex<double> beta,
1181                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1182 
1183   // Performs a Hermitian rank-k update.
1184   //
1185   //     c <- alpha * a * conj(a') + beta * c,
1186   // or
1187   //     c <- alpha * conj(a') * a + beta * c,
1188   //
1189   // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a is an n-by-k
1190   // matrix in the first case and a k-by-n matrix in the second case.
1191   virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
1192                           blas::Transpose trans, uint64 n, uint64 k,
1193                           float alpha,
1194                           const DeviceMemory<std::complex<float>> &a, int lda,
1195                           float beta, DeviceMemory<std::complex<float>> *c,
1196                           int ldc) = 0;
1197   virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
1198                           blas::Transpose trans, uint64 n, uint64 k,
1199                           double alpha,
1200                           const DeviceMemory<std::complex<double>> &a, int lda,
1201                           double beta, DeviceMemory<std::complex<double>> *c,
1202                           int ldc) = 0;
1203 
1204   // Performs a Hermitian rank-2k update.
1205   //
1206   //     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
1207   // or
1208   //     c <- alpha * conj(b') * a + conj(alpha) * conj(a') * b + beta * c,
1209   //
1210   // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a and b are
1211   // n-by-k matrices in the first case and k-by-n matrices in the second case.
1212   virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
1213                            blas::Transpose trans, uint64 n, uint64 k,
1214                            std::complex<float> alpha,
1215                            const DeviceMemory<std::complex<float>> &a, int lda,
1216                            const DeviceMemory<std::complex<float>> &b, int ldb,
1217                            float beta, DeviceMemory<std::complex<float>> *c,
1218                            int ldc) = 0;
1219   virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
1220                            blas::Transpose trans, uint64 n, uint64 k,
1221                            std::complex<double> alpha,
1222                            const DeviceMemory<std::complex<double>> &a, int lda,
1223                            const DeviceMemory<std::complex<double>> &b, int ldb,
1224                            double beta, DeviceMemory<std::complex<double>> *c,
1225                            int ldc) = 0;
1226 
1227   // Computes a matrix-matrix product where one input matrix is symmetric.
1228   //
1229   //     c <- alpha * a * b + beta * c,
1230   // or
1231   //     c <- alpha * b * a + beta * c,
1232   //
1233   // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
1234   // matrices.
1235   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1236                           blas::UpperLower uplo, uint64 m, uint64 n,
1237                           float alpha, const DeviceMemory<float> &a, int lda,
1238                           const DeviceMemory<float> &b, int ldb, float beta,
1239                           DeviceMemory<float> *c, int ldc) = 0;
1240   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1241                           blas::UpperLower uplo, uint64 m, uint64 n,
1242                           double alpha, const DeviceMemory<double> &a, int lda,
1243                           const DeviceMemory<double> &b, int ldb, double beta,
1244                           DeviceMemory<double> *c, int ldc) = 0;
1245   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1246                           blas::UpperLower uplo, uint64 m, uint64 n,
1247                           std::complex<float> alpha,
1248                           const DeviceMemory<std::complex<float>> &a, int lda,
1249                           const DeviceMemory<std::complex<float>> &b, int ldb,
1250                           std::complex<float> beta,
1251                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1252   virtual bool DoBlasSymm(Stream *stream, blas::Side side,
1253                           blas::UpperLower uplo, uint64 m, uint64 n,
1254                           std::complex<double> alpha,
1255                           const DeviceMemory<std::complex<double>> &a, int lda,
1256                           const DeviceMemory<std::complex<double>> &b, int ldb,
1257                           std::complex<double> beta,
1258                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1259 
1260   // Performs a symmetric rank-k update.
1261   //
1262   //     c <- alpha * a * a' + beta * c,
1263   // or
1264   //     c <- alpha * a' * a + beta * c,
1265   //
1266   // alpha and beta are scalars; c is a n-by-n symmetric matrix; a is an n-by-k
1267   // matrix in the first case and a k-by-n matrix in the second case.
1268   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1269                           blas::Transpose trans, uint64 n, uint64 k,
1270                           float alpha, const DeviceMemory<float> &a, int lda,
1271                           float beta, DeviceMemory<float> *c, int ldc) = 0;
1272   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1273                           blas::Transpose trans, uint64 n, uint64 k,
1274                           double alpha, const DeviceMemory<double> &a, int lda,
1275                           double beta, DeviceMemory<double> *c, int ldc) = 0;
1276   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1277                           blas::Transpose trans, uint64 n, uint64 k,
1278                           std::complex<float> alpha,
1279                           const DeviceMemory<std::complex<float>> &a, int lda,
1280                           std::complex<float> beta,
1281                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1282   virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
1283                           blas::Transpose trans, uint64 n, uint64 k,
1284                           std::complex<double> alpha,
1285                           const DeviceMemory<std::complex<double>> &a, int lda,
1286                           std::complex<double> beta,
1287                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1288 
1289   // Performs a symmetric rank-2k update.
1290   //
1291   //     c <- alpha * a * b' + alpha * b * a' + beta * c,
1292   // or
1293   //     c <- alpha * b' * a + alpha * a' * b + beta * c,
1294   //
1295   // alpha and beta are scalars; c is a n-by-n symmetric matrix; a and b are
1296   // n-by-k matrices in the first case and k-by-n matrices in the second case.
1297   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1298                            blas::Transpose trans, uint64 n, uint64 k,
1299                            float alpha, const DeviceMemory<float> &a, int lda,
1300                            const DeviceMemory<float> &b, int ldb, float beta,
1301                            DeviceMemory<float> *c, int ldc) = 0;
1302   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1303                            blas::Transpose trans, uint64 n, uint64 k,
1304                            double alpha, const DeviceMemory<double> &a, int lda,
1305                            const DeviceMemory<double> &b, int ldb, double beta,
1306                            DeviceMemory<double> *c, int ldc) = 0;
1307   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1308                            blas::Transpose trans, uint64 n, uint64 k,
1309                            std::complex<float> alpha,
1310                            const DeviceMemory<std::complex<float>> &a, int lda,
1311                            const DeviceMemory<std::complex<float>> &b, int ldb,
1312                            std::complex<float> beta,
1313                            DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1314   virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
1315                            blas::Transpose trans, uint64 n, uint64 k,
1316                            std::complex<double> alpha,
1317                            const DeviceMemory<std::complex<double>> &a, int lda,
1318                            const DeviceMemory<std::complex<double>> &b, int ldb,
1319                            std::complex<double> beta,
1320                            DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1321 
1322   // Computes a matrix-matrix product where one input matrix is triangular.
1323   //
1324   //     b <- alpha * op(a) * b,
1325   // or
1326   //     b <- alpha * b * op(a)
1327   //
1328   // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
1329   // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
1330   // op(a) = conj(a').
1331   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1332                           blas::UpperLower uplo, blas::Transpose transa,
1333                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
1334                           const DeviceMemory<float> &a, int lda,
1335                           DeviceMemory<float> *b, int ldb) = 0;
1336   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1337                           blas::UpperLower uplo, blas::Transpose transa,
1338                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
1339                           const DeviceMemory<double> &a, int lda,
1340                           DeviceMemory<double> *b, int ldb) = 0;
1341   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1342                           blas::UpperLower uplo, blas::Transpose transa,
1343                           blas::Diagonal diag, uint64 m, uint64 n,
1344                           std::complex<float> alpha,
1345                           const DeviceMemory<std::complex<float>> &a, int lda,
1346                           DeviceMemory<std::complex<float>> *b, int ldb) = 0;
1347   virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
1348                           blas::UpperLower uplo, blas::Transpose transa,
1349                           blas::Diagonal diag, uint64 m, uint64 n,
1350                           std::complex<double> alpha,
1351                           const DeviceMemory<std::complex<double>> &a, int lda,
1352                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1353 
1354   // Solves a triangular matrix equation.
1355   //
1356   //     op(a) * x = alpha * b,
1357   // or
1358   //     x * op(a) = alpha * b
1359   //
1360   // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
1361   // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
1362   // or op(a) = conj(a').
1363   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1364                           blas::UpperLower uplo, blas::Transpose transa,
1365                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
1366                           const DeviceMemory<float> &a, int lda,
1367                           DeviceMemory<float> *b, int ldb) = 0;
1368   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1369                           blas::UpperLower uplo, blas::Transpose transa,
1370                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
1371                           const DeviceMemory<double> &a, int lda,
1372                           DeviceMemory<double> *b, int ldb) = 0;
1373   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1374                           blas::UpperLower uplo, blas::Transpose transa,
1375                           blas::Diagonal diag, uint64 m, uint64 n,
1376                           std::complex<float> alpha,
1377                           const DeviceMemory<std::complex<float>> &a, int lda,
1378                           DeviceMemory<std::complex<float>> *b, int ldb) = 0;
1379   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
1380                           blas::UpperLower uplo, blas::Transpose transa,
1381                           blas::Diagonal diag, uint64 m, uint64 n,
1382                           std::complex<double> alpha,
1383                           const DeviceMemory<std::complex<double>> &a, int lda,
1384                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1385 
1386   virtual port::Status GetVersion(string *version) = 0;
1387 
1388  protected:
BlasSupport()1389   BlasSupport() {}
1390 
1391  private:
1392   SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
1393 };
1394 
1395 // Macro used to quickly declare overrides for abstract virtuals in the
1396 // BlasSupport base class.
1397 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES                  \
1398   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1399                   const DeviceMemory<float> &x, int incx,                      \
1400                   DeviceMemory<float> *result) override;                       \
1401   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1402                   const DeviceMemory<double> &x, int incx,                     \
1403                   DeviceMemory<double> *result) override;                      \
1404   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1405                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1406                   DeviceMemory<float> *result) override;                       \
1407   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1408                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1409                   DeviceMemory<double> *result) override;                      \
1410   bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,              \
1411                   const DeviceMemory<float> &x, int incx,                      \
1412                   DeviceMemory<float> *y, int incy) override;                  \
1413   bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,             \
1414                   const DeviceMemory<double> &x, int incx,                     \
1415                   DeviceMemory<double> *y, int incy) override;                 \
1416   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1417                   std::complex<float> alpha,                                   \
1418                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1419                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1420   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1421                   std::complex<double> alpha,                                  \
1422                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1423                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1424   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1425                   const DeviceMemory<float> &x, int incx,                      \
1426                   DeviceMemory<float> *y, int incy) override;                  \
1427   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1428                   const DeviceMemory<double> &x, int incx,                     \
1429                   DeviceMemory<double> *y, int incy) override;                 \
1430   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1431                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1432                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1433   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1434                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1435                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1436   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1437                  const DeviceMemory<float> &x, int incx,                       \
1438                  const DeviceMemory<float> &y, int incy,                       \
1439                  DeviceMemory<float> *result) override;                        \
1440   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1441                  const DeviceMemory<double> &x, int incx,                      \
1442                  const DeviceMemory<double> &y, int incy,                      \
1443                  DeviceMemory<double> *result) override;                       \
1444   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1445                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1446                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1447                   DeviceMemory<std::complex<float>> *result) override;         \
1448   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1449                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1450                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1451                   DeviceMemory<std::complex<double>> *result) override;        \
1452   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1453                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1454                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1455                   DeviceMemory<std::complex<float>> *result) override;         \
1456   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1457                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1458                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1459                   DeviceMemory<std::complex<double>> *result) override;        \
1460   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1461                   const DeviceMemory<float> &x, int incx,                      \
1462                   DeviceMemory<float> *result) override;                       \
1463   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1464                   const DeviceMemory<double> &x, int incx,                     \
1465                   DeviceMemory<double> *result) override;                      \
1466   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1467                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1468                   DeviceMemory<float> *result) override;                       \
1469   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1470                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1471                   DeviceMemory<double> *result) override;                      \
1472   bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,    \
1473                  int incx, DeviceMemory<float> *y, int incy, float c, float s) \
1474       override;                                                                \
1475   bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,   \
1476                  int incx, DeviceMemory<double> *y, int incy, double c,        \
1477                  double s) override;                                           \
1478   bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1479                  DeviceMemory<std::complex<float>> *x, int incx,               \
1480                  DeviceMemory<std::complex<float>> *y, int incy, float c,      \
1481                  float s) override;                                            \
1482   bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1483                  DeviceMemory<std::complex<double>> *x, int incx,              \
1484                  DeviceMemory<std::complex<double>> *y, int incy, double c,    \
1485                  double s) override;                                           \
1486   bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,                      \
1487                   DeviceMemory<float> *b, DeviceMemory<float> *c,              \
1488                   DeviceMemory<float> *s) override;                            \
1489   bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,                     \
1490                   DeviceMemory<double> *b, DeviceMemory<double> *c,            \
1491                   DeviceMemory<double> *s) override;                           \
1492   bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,        \
1493                   DeviceMemory<std::complex<float>> *b,                        \
1494                   DeviceMemory<float> *c,                                      \
1495                   DeviceMemory<std::complex<float>> *s) override;              \
1496   bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,       \
1497                   DeviceMemory<std::complex<double>> *b,                       \
1498                   DeviceMemory<double> *c,                                     \
1499                   DeviceMemory<std::complex<double>> *s) override;             \
1500   bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1501                   int incx, DeviceMemory<float> *y, int incy,                  \
1502                   const DeviceMemory<float> &param) override;                  \
1503   bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1504                   int incx, DeviceMemory<double> *y, int incy,                 \
1505                   const DeviceMemory<double> &param) override;                 \
1506   bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,                    \
1507                    DeviceMemory<float> *d2, DeviceMemory<float> *x1,           \
1508                    const DeviceMemory<float> &y1, DeviceMemory<float> *param)  \
1509       override;                                                                \
1510   bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,                   \
1511                    DeviceMemory<double> *d2, DeviceMemory<double> *x1,         \
1512                    const DeviceMemory<double> &y1,                             \
1513                    DeviceMemory<double> *param) override;                      \
1514   bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1515                   DeviceMemory<float> *x, int incx) override;                  \
1516   bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1517                   DeviceMemory<double> *x, int incx) override;                 \
1518   bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1519                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1520   bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1521                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1522   bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1523                   std::complex<float> alpha,                                   \
1524                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1525   bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1526                   std::complex<double> alpha,                                  \
1527                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1528   bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1529                   int incx, DeviceMemory<float> *y, int incy) override;        \
1530   bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1531                   int incx, DeviceMemory<double> *y, int incy) override;       \
1532   bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1533                   DeviceMemory<std::complex<float>> *x, int incx,              \
1534                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1535   bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1536                   DeviceMemory<std::complex<double>> *x, int incx,             \
1537                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1538   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1539                    const DeviceMemory<float> &x, int incx,                     \
1540                    DeviceMemory<int> *result) override;                        \
1541   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1542                    const DeviceMemory<double> &x, int incx,                    \
1543                    DeviceMemory<int> *result) override;                        \
1544   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1545                    const DeviceMemory<std::complex<float>> &x, int incx,       \
1546                    DeviceMemory<int> *result) override;                        \
1547   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1548                    const DeviceMemory<std::complex<double>> &x, int incx,      \
1549                    DeviceMemory<int> *result) override;                        \
1550   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1551                    const DeviceMemory<float> &x, int incx,                     \
1552                    DeviceMemory<int> *result) override;                        \
1553   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1554                    const DeviceMemory<double> &x, int incx,                    \
1555                    DeviceMemory<int> *result) override;                        \
1556   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1557                    const DeviceMemory<std::complex<float>> &x, int incx,       \
1558                    DeviceMemory<int> *result) override;                        \
1559   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1560                    const DeviceMemory<std::complex<double>> &x, int incx,      \
1561                    DeviceMemory<int> *result) override;                        \
1562   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1563                   uint64 kl, uint64 ku, float alpha,                           \
1564                   const DeviceMemory<float> &a, int lda,                       \
1565                   const DeviceMemory<float> &x, int incx, float beta,          \
1566                   DeviceMemory<float> *y, int incy) override;                  \
1567   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1568                   uint64 kl, uint64 ku, double alpha,                          \
1569                   const DeviceMemory<double> &a, int lda,                      \
1570                   const DeviceMemory<double> &x, int incx, double beta,        \
1571                   DeviceMemory<double> *y, int incy) override;                 \
1572   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1573                   uint64 kl, uint64 ku, std::complex<float> alpha,             \
1574                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1575                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1576                   std::complex<float> beta,                                    \
1577                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1578   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1579                   uint64 kl, uint64 ku, std::complex<double> alpha,            \
1580                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1581                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1582                   std::complex<double> beta,                                   \
1583                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1584   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1585                   float alpha, const DeviceMemory<float> &a, int lda,          \
1586                   const DeviceMemory<float> &x, int incx, float beta,          \
1587                   DeviceMemory<float> *y, int incy) override;                  \
1588   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1589                   double alpha, const DeviceMemory<double> &a, int lda,        \
1590                   const DeviceMemory<double> &x, int incx, double beta,        \
1591                   DeviceMemory<double> *y, int incy) override;                 \
1592   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1593                   std::complex<float> alpha,                                   \
1594                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1595                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1596                   std::complex<float> beta,                                    \
1597                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1598   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1599                   std::complex<double> alpha,                                  \
1600                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1601                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1602                   std::complex<double> beta,                                   \
1603                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1604   bool DoBlasGemvWithProfiling(                                                \
1605       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,  \
1606       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,     \
1607       int incx, float beta, DeviceMemory<float> *y, int incy,                  \
1608       blas::ProfileResult *output_profile_result) override;                    \
1609   bool DoBlasGemvWithProfiling(                                                \
1610       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
1611       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,   \
1612       int incx, double beta, DeviceMemory<double> *y, int incy,                \
1613       blas::ProfileResult *output_profile_result) override;                    \
1614   bool DoBlasGemvWithProfiling(                                                \
1615       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
1616       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,   \
1617       int lda, const DeviceMemory<std::complex<float>> &x, int incx,           \
1618       std::complex<float> beta, DeviceMemory<std::complex<float>> *y,          \
1619       int incy, blas::ProfileResult *output_profile_result) override;          \
1620   bool DoBlasGemvWithProfiling(                                                \
1621       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
1622       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
1623       int lda, const DeviceMemory<std::complex<double>> &x, int incx,          \
1624       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,        \
1625       int incy, blas::ProfileResult *output_profile_result) override;          \
1626   bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,              \
1627                  const DeviceMemory<float> &x, int incx,                       \
1628                  const DeviceMemory<float> &y, int incy,                       \
1629                  DeviceMemory<float> *a, int lda) override;                    \
1630   bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,             \
1631                  const DeviceMemory<double> &x, int incx,                      \
1632                  const DeviceMemory<double> &y, int incy,                      \
1633                  DeviceMemory<double> *a, int lda) override;                   \
1634   bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1635                   std::complex<float> alpha,                                   \
1636                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1637                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1638                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1639   bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1640                   std::complex<double> alpha,                                  \
1641                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1642                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1643                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1644   bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1645                   std::complex<float> alpha,                                   \
1646                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1647                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1648                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1649   bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1650                   std::complex<double> alpha,                                  \
1651                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1652                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1653                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1654   bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1655                   std::complex<float> alpha,                                   \
1656                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1657                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1658                   std::complex<float> beta,                                    \
1659                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1660   bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1661                   std::complex<double> alpha,                                  \
1662                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1663                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1664                   std::complex<double> beta,                                   \
1665                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1666   bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1667                   std::complex<float> alpha,                                   \
1668                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1669                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1670                   std::complex<float> beta,                                    \
1671                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1672   bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1673                   std::complex<double> alpha,                                  \
1674                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1675                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1676                   std::complex<double> beta,                                   \
1677                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1678   bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1679                  const DeviceMemory<std::complex<float>> &x, int incx,         \
1680                  DeviceMemory<std::complex<float>> *a, int lda) override;      \
1681   bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1682                  double alpha, const DeviceMemory<std::complex<double>> &x,    \
1683                  int incx, DeviceMemory<std::complex<double>> *a, int lda)     \
1684       override;                                                                \
1685   bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1686                   std::complex<float> alpha,                                   \
1687                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1688                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1689                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1690   bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1691                   std::complex<double> alpha,                                  \
1692                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1693                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1694                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1695   bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1696                   std::complex<float> alpha,                                   \
1697                   const DeviceMemory<std::complex<float>> &ap,                 \
1698                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1699                   std::complex<float> beta,                                    \
1700                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1701   bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1702                   std::complex<double> alpha,                                  \
1703                   const DeviceMemory<std::complex<double>> &ap,                \
1704                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1705                   std::complex<double> beta,                                   \
1706                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1707   bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1708                  const DeviceMemory<std::complex<float>> &x, int incx,         \
1709                  DeviceMemory<std::complex<float>> *ap) override;              \
1710   bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1711                  double alpha, const DeviceMemory<std::complex<double>> &x,    \
1712                  int incx, DeviceMemory<std::complex<double>> *ap) override;   \
1713   bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1714                   std::complex<float> alpha,                                   \
1715                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1716                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1717                   DeviceMemory<std::complex<float>> *ap) override;             \
1718   bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1719                   std::complex<double> alpha,                                  \
1720                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1721                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1722                   DeviceMemory<std::complex<double>> *ap) override;            \
1723   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1724                   float alpha, const DeviceMemory<float> &a, int lda,          \
1725                   const DeviceMemory<float> &x, int incx, float beta,          \
1726                   DeviceMemory<float> *y, int incy) override;                  \
1727   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1728                   double alpha, const DeviceMemory<double> &a, int lda,        \
1729                   const DeviceMemory<double> &x, int incx, double beta,        \
1730                   DeviceMemory<double> *y, int incy) override;                 \
1731   bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1732                   float alpha, const DeviceMemory<float> &ap,                  \
1733                   const DeviceMemory<float> &x, int incx, float beta,          \
1734                   DeviceMemory<float> *y, int incy) override;                  \
1735   bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1736                   double alpha, const DeviceMemory<double> &ap,                \
1737                   const DeviceMemory<double> &x, int incx, double beta,        \
1738                   DeviceMemory<double> *y, int incy) override;                 \
1739   bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1740                  const DeviceMemory<float> &x, int incx,                       \
1741                  DeviceMemory<float> *ap) override;                            \
1742   bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1743                  double alpha, const DeviceMemory<double> &x, int incx,        \
1744                  DeviceMemory<double> *ap) override;                           \
1745   bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1746                   float alpha, const DeviceMemory<float> &x, int incx,         \
1747                   const DeviceMemory<float> &y, int incy,                      \
1748                   DeviceMemory<float> *ap) override;                           \
1749   bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1750                   double alpha, const DeviceMemory<double> &x, int incx,       \
1751                   const DeviceMemory<double> &y, int incy,                     \
1752                   DeviceMemory<double> *ap) override;                          \
1753   bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1754                   float alpha, const DeviceMemory<float> &a, int lda,          \
1755                   const DeviceMemory<float> &x, int incx, float beta,          \
1756                   DeviceMemory<float> *y, int incy) override;                  \
1757   bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1758                   double alpha, const DeviceMemory<double> &a, int lda,        \
1759                   const DeviceMemory<double> &x, int incx, double beta,        \
1760                   DeviceMemory<double> *y, int incy) override;                 \
1761   bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1762                  const DeviceMemory<float> &x, int incx,                       \
1763                  DeviceMemory<float> *a, int lda) override;                    \
1764   bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1765                  double alpha, const DeviceMemory<double> &x, int incx,        \
1766                  DeviceMemory<double> *a, int lda) override;                   \
1767   bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1768                   float alpha, const DeviceMemory<float> &x, int incx,         \
1769                   const DeviceMemory<float> &y, int incy,                      \
1770                   DeviceMemory<float> *a, int lda) override;                   \
1771   bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1772                   double alpha, const DeviceMemory<double> &x, int incx,       \
1773                   const DeviceMemory<double> &y, int incy,                     \
1774                   DeviceMemory<double> *a, int lda) override;                  \
1775   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1776                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1777                   uint64 k, const DeviceMemory<float> &a, int lda,             \
1778                   DeviceMemory<float> *x, int incx) override;                  \
1779   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1780                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1781                   uint64 k, const DeviceMemory<double> &a, int lda,            \
1782                   DeviceMemory<double> *x, int incx) override;                 \
1783   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1784                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1785                   uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1786                   int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1787       override;                                                                \
1788   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1789                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1790                   uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1791                   int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1792       override;                                                                \
1793   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1794                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1795                   uint64 k, const DeviceMemory<float> &a, int lda,             \
1796                   DeviceMemory<float> *x, int incx) override;                  \
1797   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1798                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1799                   uint64 k, const DeviceMemory<double> &a, int lda,            \
1800                   DeviceMemory<double> *x, int incx) override;                 \
1801   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1802                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1803                   uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1804                   int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1805       override;                                                                \
1806   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1807                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1808                   uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1809                   int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1810       override;                                                                \
1811   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1812                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1813                   const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1814                   int incx) override;                                          \
1815   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1816                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1817                   const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1818                   int incx) override;                                          \
1819   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1820                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1821                   const DeviceMemory<std::complex<float>> &ap,                 \
1822                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1823   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1824                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1825                   const DeviceMemory<std::complex<double>> &ap,                \
1826                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1827   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1828                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1829                   const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1830                   int incx) override;                                          \
1831   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1832                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1833                   const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1834                   int incx) override;                                          \
1835   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1836                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1837                   const DeviceMemory<std::complex<float>> &ap,                 \
1838                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1839   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1840                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1841                   const DeviceMemory<std::complex<double>> &ap,                \
1842                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1843   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1844                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1845                   const DeviceMemory<float> &a, int lda,                       \
1846                   DeviceMemory<float> *x, int incx) override;                  \
1847   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1848                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1849                   const DeviceMemory<double> &a, int lda,                      \
1850                   DeviceMemory<double> *x, int incx) override;                 \
1851   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1852                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1853                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1854                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1855   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1856                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1857                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1858                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1859   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1860                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1861                   const DeviceMemory<float> &a, int lda,                       \
1862                   DeviceMemory<float> *x, int incx) override;                  \
1863   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1864                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1865                   const DeviceMemory<double> &a, int lda,                      \
1866                   DeviceMemory<double> *x, int incx) override;                 \
1867   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1868                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1869                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1870                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1871   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1872                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1873                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1874                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1875   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1876                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1877                   float alpha, const DeviceMemory<Eigen::half> &a, int lda,    \
1878                   const DeviceMemory<Eigen::half> &b, int ldb, float beta,     \
1879                   DeviceMemory<Eigen::half> *c, int ldc) override;             \
1880   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1881                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1882                   float alpha, const DeviceMemory<float> &a, int lda,          \
1883                   const DeviceMemory<float> &b, int ldb, float beta,           \
1884                   DeviceMemory<float> *c, int ldc) override;                   \
1885   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1886                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1887                   double alpha, const DeviceMemory<double> &a, int lda,        \
1888                   const DeviceMemory<double> &b, int ldb, double beta,         \
1889                   DeviceMemory<double> *c, int ldc) override;                  \
1890   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1891                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1892                   std::complex<float> alpha,                                   \
1893                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1894                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
1895                   std::complex<float> beta,                                    \
1896                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
1897   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1898                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1899                   std::complex<double> alpha,                                  \
1900                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1901                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
1902                   std::complex<double> beta,                                   \
1903                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
1904   bool DoBlasGemmWithProfiling(                                                \
1905       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1906       uint64 m, uint64 n, uint64 k, float alpha,                               \
1907       const DeviceMemory<Eigen::half> &a, int lda,                             \
1908       const DeviceMemory<Eigen::half> &b, int ldb, float beta,                 \
1909       DeviceMemory<Eigen::half> *c, int ldc,                                   \
1910       blas::ProfileResult *output_profile_result) override;                    \
1911   bool DoBlasGemmWithProfiling(                                                \
1912       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1913       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
1914       int lda, const DeviceMemory<float> &b, int ldb, float beta,              \
1915       DeviceMemory<float> *c, int ldc,                                         \
1916       blas::ProfileResult *output_profile_result) override;                    \
1917   bool DoBlasGemmWithProfiling(                                                \
1918       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1919       uint64 m, uint64 n, uint64 k, double alpha,                              \
1920       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
1921       int ldb, double beta, DeviceMemory<double> *c, int ldc,                  \
1922       blas::ProfileResult *output_profile_result) override;                    \
1923   bool DoBlasGemmWithProfiling(                                                \
1924       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1925       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
1926       const DeviceMemory<std::complex<float>> &a, int lda,                     \
1927       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
1928       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
1929       blas::ProfileResult *output_profile_result) override;                    \
1930   bool DoBlasGemmWithProfiling(                                                \
1931       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1932       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
1933       const DeviceMemory<std::complex<double>> &a, int lda,                    \
1934       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
1935       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
1936       int ldc, blas::ProfileResult *output_profile_result) override;           \
1937   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
1938       override;                                                                \
  /* DoBlasGemmWithAlgorithm overloads: int8 inputs with int output, then */   \
  /* Eigen::half, float, double and complex types. alpha/beta are passed */    \
  /* as HostOrDeviceScalar; the caller selects computation_type and */         \
  /* algorithm explicitly. */                                                  \
  bool DoBlasGemmWithAlgorithm(                                                \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,      \
      const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,       \
      int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,      \
      int ldc, blas::ComputationType computation_type,                         \
      blas::AlgorithmType algorithm,                                           \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemmWithAlgorithm(                                                \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k,                                            \
      const HostOrDeviceScalar<Eigen::half> &alpha,                            \
      const DeviceMemory<Eigen::half> &a, int lda,                             \
      const DeviceMemory<Eigen::half> &b, int ldb,                             \
      const HostOrDeviceScalar<Eigen::half> &beta,                             \
      DeviceMemory<Eigen::half> *c, int ldc,                                   \
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemmWithAlgorithm(                                                \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,    \
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,     \
      int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,  \
      int ldc, blas::ComputationType computation_type,                         \
      blas::AlgorithmType algorithm,                                           \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemmWithAlgorithm(                                                \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,   \
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
      int ldb, const HostOrDeviceScalar<double> &beta,                         \
      DeviceMemory<double> *c, int ldc,                                        \
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemmWithAlgorithm(                                                \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k,                                            \
      const HostOrDeviceScalar<std::complex<float>> &alpha,                    \
      const DeviceMemory<std::complex<float>> &a, int lda,                     \
      const DeviceMemory<std::complex<float>> &b, int ldb,                     \
      const HostOrDeviceScalar<std::complex<float>> &beta,                     \
      DeviceMemory<std::complex<float>> *c, int ldc,                           \
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemmWithAlgorithm(                                                \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k,                                            \
      const HostOrDeviceScalar<std::complex<double>> &alpha,                   \
      const DeviceMemory<std::complex<double>> &a, int lda,                    \
      const DeviceMemory<std::complex<double>> &b, int ldb,                    \
      const HostOrDeviceScalar<std::complex<double>> &beta,                    \
      DeviceMemory<std::complex<double>> *c, int ldc,                          \
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
      blas::ProfileResult *output_profile_result) override;                    \
  /* DoBlasGemmBatched overloads: a, b and c are port::ArraySlices of */       \
  /* DeviceMemory pointers. The Eigen::half overload takes float */            \
  /* alpha/beta. */                                                            \
  bool DoBlasGemmBatched(                                                      \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, float alpha,                               \
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,         \
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,         \
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,      \
      int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
  bool DoBlasGemmBatched(                                                      \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, float alpha,                               \
      const port::ArraySlice<DeviceMemory<float> *> &a, int lda,               \
      const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,   \
      const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,               \
      int batch_count, ScratchAllocator *scratch_allocator) override;          \
  bool DoBlasGemmBatched(                                                      \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, double alpha,                              \
      const port::ArraySlice<DeviceMemory<double> *> &a, int lda,              \
      const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
      const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,              \
      int batch_count, ScratchAllocator *scratch_allocator) override;          \
  bool DoBlasGemmBatched(                                                      \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
      std::complex<float> beta,                                                \
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
      int batch_count, ScratchAllocator *scratch_allocator) override;          \
  bool DoBlasGemmBatched(                                                      \
      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
      uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a,         \
      int lda,                                                                 \
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b,         \
      int ldb, std::complex<double> beta,                                      \
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c,         \
      int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
2031   bool DoBlasGemmStridedBatched(                                               \
2032       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2033       uint64 m, uint64 n, uint64 k, float alpha,                               \
2034       const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a,             \
2035       const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
2036       DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
2037   bool DoBlasGemmStridedBatched(                                               \
2038       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2039       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
2040       int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb,          \
2041       int64 stride_b, float beta, DeviceMemory<float> *c, int ldc,             \
2042       int64 stride_c, int batch_count);                                        \
2043   bool DoBlasGemmStridedBatched(                                               \
2044       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2045       uint64 m, uint64 n, uint64 k, double alpha,                              \
2046       const DeviceMemory<double> &a, int lda, int64 stride_a,                  \
2047       const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta,     \
2048       DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count);      \
2049   bool DoBlasGemmStridedBatched(                                               \
2050       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2051       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2052       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,     \
2053       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,     \
2054       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
2055       int64 stride_c, int batch_count);                                        \
2056   bool DoBlasGemmStridedBatched(                                               \
2057       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2058       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2059       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,    \
2060       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,    \
2061       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
2062       int ldc, int64 stride_c, int batch_count);                               \
  /* DoBlasHemm overloads for complex<float> and complex<double>. */           \
  bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  uint64 m, uint64 n, std::complex<float> alpha,               \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  const DeviceMemory<std::complex<float>> &b, int ldb,         \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
  bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  uint64 m, uint64 n, std::complex<double> alpha,              \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  const DeviceMemory<std::complex<double>> &b, int ldb,        \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
  /* DoBlasHerk overloads: complex matrices with real (float/double) */        \
  /* alpha and beta scalars. */                                                \
  bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  float beta, DeviceMemory<std::complex<float>> *c, int ldc)   \
      override;                                                                \
  bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  double beta, DeviceMemory<std::complex<double>> *c, int ldc) \
      override;                                                                \
  /* DoBlasHer2k overloads: complex alpha but real (float/double) beta. */      \
  bool DoBlasHer2k(                                                            \
      Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
      uint64 k, std::complex<float> alpha,                                     \
      const DeviceMemory<std::complex<float>> &a, int lda,                     \
      const DeviceMemory<std::complex<float>> &b, int ldb, float beta,         \
      DeviceMemory<std::complex<float>> *c, int ldc) override;                 \
  bool DoBlasHer2k(                                                            \
      Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
      uint64 k, std::complex<double> alpha,                                    \
      const DeviceMemory<std::complex<double>> &a, int lda,                    \
      const DeviceMemory<std::complex<double>> &b, int ldb, double beta,       \
      DeviceMemory<std::complex<double>> *c, int ldc) override;                \
  /* DoBlasSymm overloads for float, double and both complex types. */         \
  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  uint64 m, uint64 n, float alpha,                             \
                  const DeviceMemory<float> &a, int lda,                       \
                  const DeviceMemory<float> &b, int ldb, float beta,           \
                  DeviceMemory<float> *c, int ldc) override;                   \
  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  uint64 m, uint64 n, double alpha,                            \
                  const DeviceMemory<double> &a, int lda,                      \
                  const DeviceMemory<double> &b, int ldb, double beta,         \
                  DeviceMemory<double> *c, int ldc) override;                  \
  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  uint64 m, uint64 n, std::complex<float> alpha,               \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  const DeviceMemory<std::complex<float>> &b, int ldb,         \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
  bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  uint64 m, uint64 n, std::complex<double> alpha,              \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  const DeviceMemory<std::complex<double>> &b, int ldb,        \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
  /* DoBlasSyrk overloads for float, double and both complex types. */         \
  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
                  const DeviceMemory<float> &a, int lda, float beta,           \
                  DeviceMemory<float> *c, int ldc) override;                   \
  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
                  const DeviceMemory<double> &a, int lda, double beta,         \
                  DeviceMemory<double> *c, int ldc) override;                  \
  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, uint64 n, uint64 k,                   \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *c, int ldc) override;     \
  bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, uint64 n, uint64 k,                   \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *c, int ldc) override;    \
  /* DoBlasSyr2k overloads for float, double and both complex types. */        \
  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
                   blas::Transpose trans, uint64 n, uint64 k, float alpha,     \
                   const DeviceMemory<float> &a, int lda,                      \
                   const DeviceMemory<float> &b, int ldb, float beta,          \
                   DeviceMemory<float> *c, int ldc) override;                  \
  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
                   blas::Transpose trans, uint64 n, uint64 k, double alpha,    \
                   const DeviceMemory<double> &a, int lda,                     \
                   const DeviceMemory<double> &b, int ldb, double beta,        \
                   DeviceMemory<double> *c, int ldc) override;                 \
  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
                   blas::Transpose trans, uint64 n, uint64 k,                  \
                   std::complex<float> alpha,                                  \
                   const DeviceMemory<std::complex<float>> &a, int lda,        \
                   const DeviceMemory<std::complex<float>> &b, int ldb,        \
                   std::complex<float> beta,                                   \
                   DeviceMemory<std::complex<float>> *c, int ldc) override;    \
  bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
                   blas::Transpose trans, uint64 n, uint64 k,                  \
                   std::complex<double> alpha,                                 \
                   const DeviceMemory<std::complex<double>> &a, int lda,       \
                   const DeviceMemory<std::complex<double>> &b, int ldb,       \
                   std::complex<double> beta,                                  \
                   DeviceMemory<std::complex<double>> *c, int ldc) override;   \
  /* DoBlasTrmm overloads for float, double and both complex types. */         \
  /* No c operand: b is a non-const DeviceMemory pointer (presumably */        \
  /* overwritten in place with the result). */                                 \
  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, float alpha, const DeviceMemory<float> &a,         \
                  int lda, DeviceMemory<float> *b, int ldb) override;          \
  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, double alpha, const DeviceMemory<double> &a,       \
                  int lda, DeviceMemory<double> *b, int ldb) override;         \
  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, std::complex<float> alpha,                         \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  DeviceMemory<std::complex<float>> *b, int ldb) override;     \
  bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, std::complex<double> alpha,                        \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  DeviceMemory<std::complex<double>> *b, int ldb) override;    \
  /* DoBlasTrsm overloads for float, double and both complex types. */         \
  /* No c operand: b is a non-const DeviceMemory pointer (presumably */        \
  /* overwritten in place with the result). */                                 \
  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, float alpha, const DeviceMemory<float> &a,         \
                  int lda, DeviceMemory<float> *b, int ldb) override;          \
  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, double alpha, const DeviceMemory<double> &a,       \
                  int lda, DeviceMemory<double> *b, int ldb) override;         \
  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, std::complex<float> alpha,                         \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  DeviceMemory<std::complex<float>> *b, int ldb) override;     \
  bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
                  blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
                  uint64 n, std::complex<double> alpha,                        \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  DeviceMemory<std::complex<double>> *b, int ldb) override;    \
  /* Presumably fills *version with a version string; returns port::Status. */ \
  /* Final entry of the macro (this declaration has no trailing backslash). */ \
  port::Status GetVersion(string *version) override;
2200 
2201 }  // namespace blas
2202 }  // namespace stream_executor
2203 
2204 #endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
2205