1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Exposes the family of BLAS routines as pre-canned high performance calls for 17 // use in conjunction with the StreamExecutor abstraction. 18 // 19 // Note that this interface is optionally supported by platforms; see 20 // StreamExecutor::SupportsBlas() for details. 21 // 22 // This abstraction makes it simple to entrain BLAS operations on GPU data into 23 // a Stream -- users typically will not use this API directly, but will use the 24 // Stream builder methods to entrain these operations "under the hood". For 25 // example: 26 // 27 // DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024); 28 // DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024); 29 // // ... populate x and y ... 30 // Stream stream{stream_exec}; 31 // stream 32 // .Init() 33 // .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1); 34 // SE_CHECK_OK(stream.BlockHostUntilDone()); 35 // 36 // By using stream operations in this manner the user can easily intermix custom 37 // kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS 38 // routines. 
39 40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 42 43 #include <complex> 44 #include <vector> 45 46 #include "tensorflow/stream_executor/host_or_device_scalar.h" 47 #include "tensorflow/stream_executor/lib/array_slice.h" 48 #include "tensorflow/stream_executor/lib/statusor.h" 49 #include "tensorflow/stream_executor/platform/port.h" 50 51 namespace Eigen { 52 struct half; 53 } // namespace Eigen 54 55 namespace stream_executor { 56 57 class Stream; 58 class ScratchAllocator; 59 60 template <typename ElemT> 61 class DeviceMemory; 62 63 namespace blas { 64 65 // Specifies whether the input matrix will be transposed or 66 // transposed+conjugated before any BLAS operations. 67 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; 68 69 // Returns a name for t. 70 string TransposeString(Transpose t); 71 72 // Specifies whether the upper or lower triangular part of a 73 // symmetric/Hermitian matrix is used. 74 enum class UpperLower { kUpper, kLower }; 75 76 // Returns a name for ul. 77 string UpperLowerString(UpperLower ul); 78 79 // Specifies whether a matrix is unit triangular. 80 enum class Diagonal { kUnit, kNonUnit }; 81 82 // Returns a name for d. 83 string DiagonalString(Diagonal d); 84 85 // Specifies whether a Hermitian matrix appears on the left or right in 86 // operation. 87 enum class Side { kLeft, kRight }; 88 89 // Returns a name for s. 90 string SideString(Side s); 91 92 // Type with which intermediate computations of a blas routine are performed. 93 // 94 // Some blas calls can perform computations with a type that's different than 95 // the type of their inputs/outputs. This lets you e.g. multiply two matrices 96 // of int8s using float32s to store the matmul's intermediate values. 
97 enum class ComputationType { 98 kF16, // 16-bit floating-point 99 kF32, // 32-bit floating-point 100 kF64, // 64-bit floating-point 101 kI32, // 32-bit integer 102 kComplexF32, // Complex number comprised of two f32s. 103 kComplexF64, // Complex number comprised of two f64s. 104 }; 105 106 // Converts a ComputationType to a string. 107 string ComputationTypeString(ComputationType ty); 108 109 std::ostream &operator<<(std::ostream &os, ComputationType ty); 110 111 // Opaque identifier for an "algorithm" used by a blas routine. This functions 112 // as a hint to the blas library. 113 typedef int64 AlgorithmType; 114 constexpr AlgorithmType kDefaultAlgorithm = -1; 115 constexpr AlgorithmType kDefaultBlasGemm = -2; 116 constexpr AlgorithmType kDefaultBlasGemv = -3; 117 constexpr AlgorithmType kNoAlgorithm = -4; 118 119 // blas uses -1 to represent the default algorithm. This happens to match up 120 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast 121 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert 122 // to ensure that this assumption does not break. 123 // If another blas implementation uses a different value for the default 124 // algorithm, then it needs to convert kDefaultGemmAlgo to that value 125 // (e.g. via a function called ToWhateverGemmAlgo). 126 constexpr AlgorithmType kDefaultGemmAlgo = -1; 127 128 // Describes the result of a performance experiment, usually timing the speed of 129 // a particular AlgorithmType. 130 // 131 // If the call we were benchmarking failed (a common occurrence; not all 132 // algorithms are valid for all calls), is_valid() will be false. 
133 class ProfileResult { 134 public: is_valid()135 bool is_valid() const { return is_valid_; } set_is_valid(bool val)136 void set_is_valid(bool val) { is_valid_ = val; } algorithm()137 AlgorithmType algorithm() const { return algorithm_; } set_algorithm(AlgorithmType val)138 void set_algorithm(AlgorithmType val) { algorithm_ = val; } elapsed_time_in_ms()139 float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } set_elapsed_time_in_ms(float val)140 void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } 141 142 private: 143 bool is_valid_ = false; 144 AlgorithmType algorithm_ = kDefaultAlgorithm; 145 float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); 146 }; 147 148 class AlgorithmConfig { 149 public: AlgorithmConfig()150 AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} AlgorithmConfig(AlgorithmType algorithm)151 explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} algorithm()152 AlgorithmType algorithm() const { return algorithm_; } set_algorithm(AlgorithmType val)153 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 154 bool operator==(const AlgorithmConfig &other) const { 155 return this->algorithm_ == other.algorithm_; 156 } 157 bool operator!=(const AlgorithmConfig &other) const { 158 return !(*this == other); 159 } 160 string ToString() const; 161 162 private: 163 AlgorithmType algorithm_; 164 }; 165 166 // BLAS support interface -- this can be derived from a GPU executor when the 167 // underlying platform has an BLAS library implementation available. See 168 // StreamExecutor::AsBlas(). 169 // 170 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in 171 // the system. 
// Any operation that a user attempts to perform by enqueueing BLAS
// operations on a thread not-associated with the CUDA-context has unknown
// behavior at the current time; see b/13176597
//
// NOTE(review): every routine below returns bool; presumably true means the
// operation was successfully enqueued on the stream -- confirm against the
// implementations.
class BlasSupport {
 public:
  virtual ~BlasSupport() {}

  // Computes the sum of magnitudes of the vector elements.
  // result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)|+ ... + |Re x(n)|
  // + |Im x(n)|.
  // Note that Im x(i) = 0 for real types float/double.
  // For the complex overloads the result is real-valued (float/double).
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) = 0;
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) = 0;

  // Performs a BLAS y <- ax+y operation.
  // Here "a" is the scalar passed as alpha; y is updated in place.
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Copies vector to another vector: y <- x.
  // NOTE(review): incx/incy are presumably the element strides of x and y, as
  // in the reference BLAS -- confirm.
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Performs a BLAS dot product result <- x . y.
  // Real-valued overloads only; the complex variants are DoBlasDotc and
  // DoBlasDotu.
  virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *result) = 0;
  virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *result) = 0;

  // Performs a BLAS dot product result <- conj(x) . y for complex types.
  virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) = 0;
  virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) = 0;

  // Performs a BLAS dot product result <- x . y for complex types. Note that
  // x is unconjugated in this routine.
  virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) = 0;
  virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) = 0;

  // Computes the Euclidean norm of a vector: result <- ||x||.
  // See the following link for more information of Euclidean norm:
  // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
  // For the complex overloads the result is real-valued (float/double).
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) = 0;
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) = 0;

  // Performs rotation of points in the plane:
  // x(i) = c*x(i) + s*y(i)
  // y(i) = c*y(i) - s*x(i).
  // Both x and y are updated in place.  In the complex overloads c and s are
  // still real scalars.
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<float> *x, int incx,
                         DeviceMemory<float> *y, int incy, float c,
                         float s) = 0;
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<double> *x, int incx,
                         DeviceMemory<double> *y, int incy, double c,
                         double s) = 0;
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<float>> *x, int incx,
                         DeviceMemory<std::complex<float>> *y, int incy,
                         float c, float s) = 0;
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<double>> *x, int incx,
                         DeviceMemory<std::complex<double>> *y, int incy,
                         double c, double s) = 0;

  // Computes the parameters for a Givens rotation.
  // Given the Cartesian coordinates (a, b) of a point, these routines return
  // the parameters c, s, r, and z associated with the Givens rotation. The
  // parameters c and s define a unitary matrix such that:
  //
  //   |  c s |.| a | = | r |
  //   | -s c | | b |   | 0 |
  //
  // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
  // c is not 0 z is 1/c; otherwise z is 1.
  //
  // NOTE(review): there are no separate r/z parameters; presumably r and z are
  // written back into a and b as in the reference BLAS -- confirm.
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
                          DeviceMemory<float> *b, DeviceMemory<float> *c,
                          DeviceMemory<float> *s) = 0;
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
                          DeviceMemory<double> *b, DeviceMemory<double> *c,
                          DeviceMemory<double> *s) = 0;
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
                          DeviceMemory<std::complex<float>> *b,
                          DeviceMemory<float> *c,
                          DeviceMemory<std::complex<float>> *s) = 0;
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
                          DeviceMemory<std::complex<double>> *b,
                          DeviceMemory<double> *c,
                          DeviceMemory<std::complex<double>> *s) = 0;

  // Performs modified Givens rotation of points in the plane.
318 // Given two vectors x and y, each vector element of these vectors is replaced 319 // as follows: 320 // 321 // | x(i) | = H | x(i) | 322 // | y(i) | | y(i) | 323 // 324 // for i=1 to n, where H is a modified Givens transformation matrix whose 325 // values are stored in the param[1] through param[4] array. 326 // For more information please Google this routine. 327 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 328 DeviceMemory<float> *x, int incx, 329 DeviceMemory<float> *y, int incy, 330 const DeviceMemory<float> ¶m) = 0; 331 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 332 DeviceMemory<double> *x, int incx, 333 DeviceMemory<double> *y, int incy, 334 const DeviceMemory<double> ¶m) = 0; 335 336 // Computes the parameters for a modified Givens rotation. 337 // Given Cartesian coordinates (x1, y1) of an input vector, these routines 338 // compute the components of a modified Givens transformation matrix H that 339 // zeros the y-component of the resulting vector: 340 // 341 // | x1 | = H | x1 * sqrt(d1) | 342 // | 0 | | y1 * sqrt(d1) | 343 // 344 // For more information please Google this routine. 345 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, 346 DeviceMemory<float> *d2, DeviceMemory<float> *x1, 347 const DeviceMemory<float> &y1, 348 DeviceMemory<float> *param) = 0; 349 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, 350 DeviceMemory<double> *d2, DeviceMemory<double> *x1, 351 const DeviceMemory<double> &y1, 352 DeviceMemory<double> *param) = 0; 353 354 // Computes the product of a vector by a scalar: x <- a*x. 
  // The scalar is passed as alpha and x is scaled in place.
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<float> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<double> *x, int incx) = 0;
  // Mixed overloads: a complex vector scaled by a real alpha.
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
  // Complex vector scaled by a complex alpha.
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;

  // Swaps a vector with another vector.
  // Both x and y are modified.
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<float>> *x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<double>> *x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Finds the index of the element with maximum absolute value.
  // The index is written to result as an int in device memory.
  // NOTE(review): the reference BLAS returns a 1-based index; confirm which
  // convention the implementations follow.
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x,
                           int incx, DeviceMemory<int> *result) = 0;

  // Finds the index of the element with minimum absolute value.
  // The index is written to result as an int in device memory.
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x,
                           int incx, DeviceMemory<int> *result) = 0;

  // Computes a matrix-vector product using a general band matrix:
  //
  //     y <- alpha * a * x + beta * y,
  // or
  //     y <- alpha * a' * x + beta * y,
  // or
  //     y <- alpha * conj(a') * x + beta * y,
  //
  // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
  // sub-diagonals and ku super-diagonals; x is a vector with
  // n(trans==kNoTranspose)/m(otherwise) elements;
  // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
  // trans selects which of the three forms above is computed; y is updated in
  // place.
  // NOTE(review): lda is presumably the leading dimension of the band storage
  // of a, as in the reference BLAS -- confirm.
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Computes a matrix-vector product using a general matrix.
  //
  //     y <- alpha * a * x + beta * y,
  // or
  //     y <- alpha * a' * x + beta * y,
  // or
  //     y <- alpha * conj(a') * x + beta * y,
  //
  // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
  // with n(trans==kNoTranspose)/m(otherwise) elements;
  // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
  // trans selects which of the three forms above is computed; y is updated in
  // place.
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Variants of DoBlasGemv that take the same arguments plus a ProfileResult
  // output pointer.
  // NOTE(review): presumably the call is timed and the measurement stored in
  // *output_profile_result -- confirm against the implementations.
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
      int incx, float beta, DeviceMemory<float> *y, int incy,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
      int incx, double beta, DeviceMemory<double> *y, int incy,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
      std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
      int lda, const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
      std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
      int lda, const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, ProfileResult *output_profile_result) = 0;

  // Performs a rank-1 update of a general matrix.
  //
  //     a <- alpha * x * y' + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix.
  // Real-valued overloads only; the complex variants are DoBlasGerc and
  // DoBlasGeru.
  virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *a, int lda) = 0;
  virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *a, int lda) = 0;

  // Performs a rank-1 update (conjugated) of a general matrix.
  //
  //     a <- alpha * x * conj(y') + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix.
  virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;

  // Performs a rank-1 update (unconjugated) of a general matrix.
  //
  //     a <- alpha * x * y' + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix.
  // Here y' is the plain (non-conjugated) transpose of y.
  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;

  // Computes a matrix-vector product using a Hermitian band matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
  // super-diagonals; x and y are n-element vectors.
  // uplo selects which triangular part of a is used.
  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Computes a matrix-vector product using a Hermitian matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
  // n-element vectors.
  // uplo selects which triangular part of a is used.
  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Performs a rank-1 update of a Hermitian matrix.
  //
  //     a <- alpha * x * conj(x') + a,
  //
  // alpha is a real scalar; x is an n-element vector; a is an n-by-n Hermitian
  // matrix.
  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *a, int lda) = 0;

  // Performs a rank-2 update of a Hermitian matrix.
  //
  //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
  //
  // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
  // matrix.
  // uplo selects which triangular part of a is used; a is updated in place.
  virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;

  // Computes a matrix-vector product using a Hermitian packed matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
  // packed form; x and y are n-element vectors.
  // The packed matrix is passed as ap, so there is no lda argument.
  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &ap,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &ap,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;

  // Performs a rank-1 update of a Hermitian packed matrix.
  //
  //     a <- alpha * x * conj(x') + a,
  //
  // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
  // matrix, supplied in packed form.
  // alpha is real in both overloads; the packed matrix is passed as ap and
  // updated in place.
  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *ap) = 0;
  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *ap) = 0;

  // Performs a rank-2 update of a Hermitian packed matrix.
  //
  //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
  //
  // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
  // matrix, supplied in packed form.
  virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *ap) = 0;
  virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *ap) = 0;

  // Computes a matrix-vector product using a symmetric band matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
  // super-diagonals; x and y are n-element vectors.
  // uplo selects which triangular part of a is used; y is updated in place.
  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) = 0;

  // Computes a matrix-vector product using a symmetric packed matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
  // packed form; x and y are n-element vectors.
  // The packed matrix is passed as ap, so there is no lda argument.
  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &ap,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &ap,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;

  // Performs a rank-1 update of a symmetric packed matrix.
  //
  //     a <- alpha * x * x' + a,
  //
  // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
  // matrix, supplied in packed form.
  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *ap) = 0;
  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *ap) = 0;

  // Performs a rank-2 update of a symmetric packed matrix.
718 // 719 // a <- alpha * x * x' + alpha * y * x' + a, 720 // 721 // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric 722 // matrix, supplied in packed form. 723 virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 724 float alpha, const DeviceMemory<float> &x, int incx, 725 const DeviceMemory<float> &y, int incy, 726 DeviceMemory<float> *ap) = 0; 727 virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 728 double alpha, const DeviceMemory<double> &x, int incx, 729 const DeviceMemory<double> &y, int incy, 730 DeviceMemory<double> *ap) = 0; 731 732 // Computes a matrix-vector product for a symmetric matrix. 733 // 734 // y <- alpha * a * x + beta * y, 735 // 736 // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are 737 // n-element vectors. 738 virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, 739 float alpha, const DeviceMemory<float> &a, int lda, 740 const DeviceMemory<float> &x, int incx, float beta, 741 DeviceMemory<float> *y, int incy) = 0; 742 virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, 743 double alpha, const DeviceMemory<double> &a, int lda, 744 const DeviceMemory<double> &x, int incx, double beta, 745 DeviceMemory<double> *y, int incy) = 0; 746 747 // Performs a rank-1 update of a symmetric matrix. 748 // 749 // a <- alpha * x * x' + a, 750 // 751 // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric 752 // matrix. 753 virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, 754 float alpha, const DeviceMemory<float> &x, int incx, 755 DeviceMemory<float> *a, int lda) = 0; 756 virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, 757 double alpha, const DeviceMemory<double> &x, int incx, 758 DeviceMemory<double> *a, int lda) = 0; 759 760 // Performs a rank-2 update of symmetric matrix. 
761 // 762 // a <- alpha * x * x' + alpha * y * x' + a, 763 // 764 // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric 765 // matrix. 766 virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, 767 float alpha, const DeviceMemory<float> &x, int incx, 768 const DeviceMemory<float> &y, int incy, 769 DeviceMemory<float> *a, int lda) = 0; 770 virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, 771 double alpha, const DeviceMemory<double> &x, int incx, 772 const DeviceMemory<double> &y, int incy, 773 DeviceMemory<double> *a, int lda) = 0; 774 775 // Computes a matrix-vector product using a triangular band matrix. 776 // 777 // x <- a * x, 778 // or 779 // x <- a' * x, 780 // or 781 // x <- conj(a') * x, 782 // 783 // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix, 784 // with k+1 diagonals; x is a n-element vector. 785 virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 786 blas::Transpose trans, blas::Diagonal diag, uint64 n, 787 uint64 k, const DeviceMemory<float> &a, int lda, 788 DeviceMemory<float> *x, int incx) = 0; 789 virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 790 blas::Transpose trans, blas::Diagonal diag, uint64 n, 791 uint64 k, const DeviceMemory<double> &a, int lda, 792 DeviceMemory<double> *x, int incx) = 0; 793 virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 794 blas::Transpose trans, blas::Diagonal diag, uint64 n, 795 uint64 k, const DeviceMemory<std::complex<float>> &a, 796 int lda, DeviceMemory<std::complex<float>> *x, 797 int incx) = 0; 798 virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 799 blas::Transpose trans, blas::Diagonal diag, uint64 n, 800 uint64 k, const DeviceMemory<std::complex<double>> &a, 801 int lda, DeviceMemory<std::complex<double>> *x, 802 int incx) = 0; 803 804 // Solves a system of linear equations whose coefficients are in a triangular 805 // band matrix as below: 806 // 807 // a * x = 
b, 808 // or 809 // a' * x = b, 810 // or 811 // conj(a') * x = b, 812 // 813 // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or 814 // lower triangular band matrix, with k+1 diagonals. 815 virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 816 blas::Transpose trans, blas::Diagonal diag, uint64 n, 817 uint64 k, const DeviceMemory<float> &a, int lda, 818 DeviceMemory<float> *x, int incx) = 0; 819 virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 820 blas::Transpose trans, blas::Diagonal diag, uint64 n, 821 uint64 k, const DeviceMemory<double> &a, int lda, 822 DeviceMemory<double> *x, int incx) = 0; 823 virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 824 blas::Transpose trans, blas::Diagonal diag, uint64 n, 825 uint64 k, const DeviceMemory<std::complex<float>> &a, 826 int lda, DeviceMemory<std::complex<float>> *x, 827 int incx) = 0; 828 virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 829 blas::Transpose trans, blas::Diagonal diag, uint64 n, 830 uint64 k, const DeviceMemory<std::complex<double>> &a, 831 int lda, DeviceMemory<std::complex<double>> *x, 832 int incx) = 0; 833 834 // Computes a matrix-vector product using a triangular packed matrix. 835 // 836 // x <- a * x, 837 // or 838 // x <- a' * x, 839 // or 840 // x <- conj(a') * x, 841 // 842 // a is an n-by-n unit, or non-unit, upper or lower triangular matrix, 843 // supplied in packed form; x is a n-element vector. 
844 virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, 845 blas::Transpose trans, blas::Diagonal diag, uint64 n, 846 const DeviceMemory<float> &ap, DeviceMemory<float> *x, 847 int incx) = 0; 848 virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, 849 blas::Transpose trans, blas::Diagonal diag, uint64 n, 850 const DeviceMemory<double> &ap, 851 DeviceMemory<double> *x, int incx) = 0; 852 virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, 853 blas::Transpose trans, blas::Diagonal diag, uint64 n, 854 const DeviceMemory<std::complex<float>> &ap, 855 DeviceMemory<std::complex<float>> *x, int incx) = 0; 856 virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, 857 blas::Transpose trans, blas::Diagonal diag, uint64 n, 858 const DeviceMemory<std::complex<double>> &ap, 859 DeviceMemory<std::complex<double>> *x, int incx) = 0; 860 861 // Solves a system of linear equations whose coefficients are in a triangular 862 // packed matrix as below: 863 // 864 // a * x = b, 865 // or 866 // a' * x = b, 867 // or 868 // conj(a') * x = b, 869 // 870 // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or 871 // lower triangular matrix, supplied in packed form. 
872 virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, 873 blas::Transpose trans, blas::Diagonal diag, uint64 n, 874 const DeviceMemory<float> &ap, DeviceMemory<float> *x, 875 int incx) = 0; 876 virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, 877 blas::Transpose trans, blas::Diagonal diag, uint64 n, 878 const DeviceMemory<double> &ap, 879 DeviceMemory<double> *x, int incx) = 0; 880 virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, 881 blas::Transpose trans, blas::Diagonal diag, uint64 n, 882 const DeviceMemory<std::complex<float>> &ap, 883 DeviceMemory<std::complex<float>> *x, int incx) = 0; 884 virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, 885 blas::Transpose trans, blas::Diagonal diag, uint64 n, 886 const DeviceMemory<std::complex<double>> &ap, 887 DeviceMemory<std::complex<double>> *x, int incx) = 0; 888 889 // Computes a matrix-vector product using a triangular matrix. 890 // 891 // x <- a * x, 892 // or 893 // x <- a' * x, 894 // or 895 // x <- conj(a') * x, 896 // 897 // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is a 898 // n-element vector. 
899 virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, 900 blas::Transpose trans, blas::Diagonal diag, uint64 n, 901 const DeviceMemory<float> &a, int lda, 902 DeviceMemory<float> *x, int incx) = 0; 903 virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, 904 blas::Transpose trans, blas::Diagonal diag, uint64 n, 905 const DeviceMemory<double> &a, int lda, 906 DeviceMemory<double> *x, int incx) = 0; 907 virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, 908 blas::Transpose trans, blas::Diagonal diag, uint64 n, 909 const DeviceMemory<std::complex<float>> &a, int lda, 910 DeviceMemory<std::complex<float>> *x, int incx) = 0; 911 virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, 912 blas::Transpose trans, blas::Diagonal diag, uint64 n, 913 const DeviceMemory<std::complex<double>> &a, int lda, 914 DeviceMemory<std::complex<double>> *x, int incx) = 0; 915 916 // Solves a system of linear equations whose coefficients are in a triangular 917 // matrix as below: 918 // 919 // a * x = b, 920 // or 921 // a' * x = b, 922 // or 923 // conj(a') * x = b, 924 // 925 // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or 926 // lower triangular matrix. 
927 virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, 928 blas::Transpose trans, blas::Diagonal diag, uint64 n, 929 const DeviceMemory<float> &a, int lda, 930 DeviceMemory<float> *x, int incx) = 0; 931 virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, 932 blas::Transpose trans, blas::Diagonal diag, uint64 n, 933 const DeviceMemory<double> &a, int lda, 934 DeviceMemory<double> *x, int incx) = 0; 935 virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, 936 blas::Transpose trans, blas::Diagonal diag, uint64 n, 937 const DeviceMemory<std::complex<float>> &a, int lda, 938 DeviceMemory<std::complex<float>> *x, int incx) = 0; 939 virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, 940 blas::Transpose trans, blas::Diagonal diag, uint64 n, 941 const DeviceMemory<std::complex<double>> &a, int lda, 942 DeviceMemory<std::complex<double>> *x, int incx) = 0; 943 944 // Computes a matrix-matrix product with general matrices: 945 // 946 // c <- alpha * op(a) * op(b) + beta * c, 947 // 948 // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and 949 // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix; 950 // op(b) is a k-by-n matrix; c is an m-by-n matrix. 951 // 952 // Note: The half interface uses float precision internally; the version 953 // that uses half precision internally is not yet supported. There is no 954 // batched version of the half-precision interface. 
955 virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, 956 blas::Transpose transb, uint64 m, uint64 n, uint64 k, 957 float alpha, const DeviceMemory<Eigen::half> &a, 958 int lda, const DeviceMemory<Eigen::half> &b, int ldb, 959 float beta, DeviceMemory<Eigen::half> *c, 960 int ldc) = 0; 961 virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, 962 blas::Transpose transb, uint64 m, uint64 n, uint64 k, 963 float alpha, const DeviceMemory<float> &a, int lda, 964 const DeviceMemory<float> &b, int ldb, float beta, 965 DeviceMemory<float> *c, int ldc) = 0; 966 virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, 967 blas::Transpose transb, uint64 m, uint64 n, uint64 k, 968 double alpha, const DeviceMemory<double> &a, int lda, 969 const DeviceMemory<double> &b, int ldb, double beta, 970 DeviceMemory<double> *c, int ldc) = 0; 971 virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, 972 blas::Transpose transb, uint64 m, uint64 n, uint64 k, 973 std::complex<float> alpha, 974 const DeviceMemory<std::complex<float>> &a, int lda, 975 const DeviceMemory<std::complex<float>> &b, int ldb, 976 std::complex<float> beta, 977 DeviceMemory<std::complex<float>> *c, int ldc) = 0; 978 virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, 979 blas::Transpose transb, uint64 m, uint64 n, uint64 k, 980 std::complex<double> alpha, 981 const DeviceMemory<std::complex<double>> &a, int lda, 982 const DeviceMemory<std::complex<double>> &b, int ldb, 983 std::complex<double> beta, 984 DeviceMemory<std::complex<double>> *c, int ldc) = 0; 985 986 virtual bool DoBlasGemmWithProfiling( 987 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 988 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, 989 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta, 990 DeviceMemory<Eigen::half> *c, int ldc, 991 ProfileResult *output_profile_result) = 0; 992 virtual bool DoBlasGemmWithProfiling( 993 Stream 
*stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 994 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, 995 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, 996 int ldc, ProfileResult *output_profile_result) = 0; 997 virtual bool DoBlasGemmWithProfiling( 998 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 999 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, 1000 const DeviceMemory<double> &b, int ldb, double beta, 1001 DeviceMemory<double> *c, int ldc, 1002 ProfileResult *output_profile_result) = 0; 1003 virtual bool DoBlasGemmWithProfiling( 1004 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1005 uint64 n, uint64 k, std::complex<float> alpha, 1006 const DeviceMemory<std::complex<float>> &a, int lda, 1007 const DeviceMemory<std::complex<float>> &b, int ldb, 1008 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, 1009 ProfileResult *output_profile_result) = 0; 1010 virtual bool DoBlasGemmWithProfiling( 1011 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1012 uint64 n, uint64 k, std::complex<double> alpha, 1013 const DeviceMemory<std::complex<double>> &a, int lda, 1014 const DeviceMemory<std::complex<double>> &b, int ldb, 1015 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 1016 ProfileResult *output_profile_result) = 0; 1017 1018 // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. 1019 virtual bool GetBlasGemmAlgorithms( 1020 std::vector<AlgorithmType> *out_algorithms) = 0; 1021 1022 // Like DoBlasGemm, but accepts an algorithm and an compute type. 1023 // 1024 // The compute type lets you say (e.g.) that the inputs and outputs are 1025 // Eigen::halfs, but you want the internal computations to be done with 1026 // float32 precision. 
1027 // 1028 // Note the subtle difference in the version that accepts Eigen:::half -- 1029 // alpha and beta have type const Eigen::half&, not float. 1030 // 1031 // If output_profile_result is not null, a failure here does not put the 1032 // stream in a failure state. Instead, success/failure is indicated by 1033 // output_profile_result->is_valid(). This lets you use this function for 1034 // choosing the best algorithm among many (some of which may fail) without 1035 // creating a new Stream for each attempt. 1036 virtual bool DoBlasGemmWithAlgorithm( 1037 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1038 uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha, 1039 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, 1040 int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, 1041 int ldc, ComputationType computation_type, AlgorithmType algorithm, 1042 ProfileResult *output_profile_result) = 0; 1043 virtual bool DoBlasGemmWithAlgorithm( 1044 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1045 uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha, 1046 const DeviceMemory<Eigen::half> &a, int lda, 1047 const DeviceMemory<Eigen::half> &b, int ldb, 1048 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c, 1049 int ldc, ComputationType computation_type, AlgorithmType algorithm, 1050 ProfileResult *output_profile_result) = 0; 1051 virtual bool DoBlasGemmWithAlgorithm( 1052 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1053 uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha, 1054 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, 1055 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, 1056 int ldc, ComputationType computation_type, AlgorithmType algorithm, 1057 ProfileResult *output_profile_result) = 0; 1058 virtual bool DoBlasGemmWithAlgorithm( 1059 Stream *stream, 
blas::Transpose transa, blas::Transpose transb, uint64 m, 1060 uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha, 1061 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, 1062 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c, 1063 int ldc, ComputationType computation_type, AlgorithmType algorithm, 1064 ProfileResult *output_profile_result) = 0; 1065 virtual bool DoBlasGemmWithAlgorithm( 1066 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1067 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha, 1068 const DeviceMemory<std::complex<float>> &a, int lda, 1069 const DeviceMemory<std::complex<float>> &b, int ldb, 1070 const HostOrDeviceScalar<std::complex<float>> &beta, 1071 DeviceMemory<std::complex<float>> *c, int ldc, 1072 ComputationType computation_type, AlgorithmType algorithm, 1073 ProfileResult *output_profile_result) = 0; 1074 virtual bool DoBlasGemmWithAlgorithm( 1075 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1076 uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha, 1077 const DeviceMemory<std::complex<double>> &a, int lda, 1078 const DeviceMemory<std::complex<double>> &b, int ldb, 1079 const HostOrDeviceScalar<std::complex<double>> &beta, 1080 DeviceMemory<std::complex<double>> *c, int ldc, 1081 ComputationType computation_type, AlgorithmType algorithm, 1082 ProfileResult *output_profile_result) = 0; 1083 1084 // Computes a batch of matrix-matrix product with general matrices. 1085 // This is a batched version of DoBlasGemm. 1086 // The batched GEMM computes matrix product for each input/output in a, b, 1087 // and c, which contain batch_count DeviceMemory objects. 
1088 virtual bool DoBlasGemmBatched( 1089 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1090 uint64 n, uint64 k, float alpha, 1091 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, 1092 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, 1093 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, 1094 int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; 1095 virtual bool DoBlasGemmBatched( 1096 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1097 uint64 n, uint64 k, float alpha, 1098 const port::ArraySlice<DeviceMemory<float> *> &a, int lda, 1099 const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, 1100 const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, 1101 int batch_count, ScratchAllocator *scratch_allocator) = 0; 1102 virtual bool DoBlasGemmBatched( 1103 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1104 uint64 n, uint64 k, double alpha, 1105 const port::ArraySlice<DeviceMemory<double> *> &a, int lda, 1106 const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, 1107 const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, 1108 int batch_count, ScratchAllocator *scratch_allocator) = 0; 1109 virtual bool DoBlasGemmBatched( 1110 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1111 uint64 n, uint64 k, std::complex<float> alpha, 1112 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, 1113 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, 1114 std::complex<float> beta, 1115 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, 1116 int batch_count, ScratchAllocator *scratch_allocator) = 0; 1117 virtual bool DoBlasGemmBatched( 1118 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1119 uint64 n, uint64 k, std::complex<double> alpha, 1120 const 
port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, 1121 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, 1122 std::complex<double> beta, 1123 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, 1124 int batch_count, ScratchAllocator *scratch_allocator) = 0; 1125 1126 // Batched gemm with strides instead of pointer arrays. 1127 virtual bool DoBlasGemmStridedBatched( 1128 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1129 uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, 1130 int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, 1131 int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, 1132 int64 stride_c, int batch_count) = 0; 1133 virtual bool DoBlasGemmStridedBatched( 1134 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1135 uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, 1136 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, 1137 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, 1138 int batch_count) = 0; 1139 virtual bool DoBlasGemmStridedBatched( 1140 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1141 uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, 1142 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, 1143 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, 1144 int batch_count) = 0; 1145 virtual bool DoBlasGemmStridedBatched( 1146 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1147 uint64 n, uint64 k, std::complex<float> alpha, 1148 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, 1149 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, 1150 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, 1151 int64 stride_c, int batch_count) = 0; 1152 virtual bool 
DoBlasGemmStridedBatched( 1153 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 1154 uint64 n, uint64 k, std::complex<double> alpha, 1155 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, 1156 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, 1157 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 1158 int64 stride_c, int batch_count) = 0; 1159 1160 // Computes a matrix-matrix product where one input matrix is Hermitian: 1161 // 1162 // c <- alpha * a * b + beta * c, 1163 // or 1164 // c <- alpha * b * a + beta * c, 1165 // 1166 // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n 1167 // matrices. 1168 virtual bool DoBlasHemm(Stream *stream, blas::Side side, 1169 blas::UpperLower uplo, uint64 m, uint64 n, 1170 std::complex<float> alpha, 1171 const DeviceMemory<std::complex<float>> &a, int lda, 1172 const DeviceMemory<std::complex<float>> &b, int ldb, 1173 std::complex<float> beta, 1174 DeviceMemory<std::complex<float>> *c, int ldc) = 0; 1175 virtual bool DoBlasHemm(Stream *stream, blas::Side side, 1176 blas::UpperLower uplo, uint64 m, uint64 n, 1177 std::complex<double> alpha, 1178 const DeviceMemory<std::complex<double>> &a, int lda, 1179 const DeviceMemory<std::complex<double>> &b, int ldb, 1180 std::complex<double> beta, 1181 DeviceMemory<std::complex<double>> *c, int ldc) = 0; 1182 1183 // Performs a Hermitian rank-k update. 1184 // 1185 // c <- alpha * a * conj(a') + beta * c, 1186 // or 1187 // c <- alpha * conj(a') * a + beta * c, 1188 // 1189 // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a is an n-by-k 1190 // matrix in the first case and a k-by-n matrix in the second case. 
1191 virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, 1192 blas::Transpose trans, uint64 n, uint64 k, 1193 float alpha, 1194 const DeviceMemory<std::complex<float>> &a, int lda, 1195 float beta, DeviceMemory<std::complex<float>> *c, 1196 int ldc) = 0; 1197 virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, 1198 blas::Transpose trans, uint64 n, uint64 k, 1199 double alpha, 1200 const DeviceMemory<std::complex<double>> &a, int lda, 1201 double beta, DeviceMemory<std::complex<double>> *c, 1202 int ldc) = 0; 1203 1204 // Performs a Hermitian rank-2k update. 1205 // 1206 // c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c, 1207 // or 1208 // c <- alpha * conj(b') * a + conj(alpha) * conj(a') * b + beta * c, 1209 // 1210 // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a and b are 1211 // n-by-k matrices in the first case and k-by-n matrices in the second case. 1212 virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo, 1213 blas::Transpose trans, uint64 n, uint64 k, 1214 std::complex<float> alpha, 1215 const DeviceMemory<std::complex<float>> &a, int lda, 1216 const DeviceMemory<std::complex<float>> &b, int ldb, 1217 float beta, DeviceMemory<std::complex<float>> *c, 1218 int ldc) = 0; 1219 virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo, 1220 blas::Transpose trans, uint64 n, uint64 k, 1221 std::complex<double> alpha, 1222 const DeviceMemory<std::complex<double>> &a, int lda, 1223 const DeviceMemory<std::complex<double>> &b, int ldb, 1224 double beta, DeviceMemory<std::complex<double>> *c, 1225 int ldc) = 0; 1226 1227 // Computes a matrix-matrix product where one input matrix is symmetric. 1228 // 1229 // c <- alpha * a * b + beta * c, 1230 // or 1231 // c <- alpha * b * a + beta * c, 1232 // 1233 // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n 1234 // matrices. 
1235 virtual bool DoBlasSymm(Stream *stream, blas::Side side, 1236 blas::UpperLower uplo, uint64 m, uint64 n, 1237 float alpha, const DeviceMemory<float> &a, int lda, 1238 const DeviceMemory<float> &b, int ldb, float beta, 1239 DeviceMemory<float> *c, int ldc) = 0; 1240 virtual bool DoBlasSymm(Stream *stream, blas::Side side, 1241 blas::UpperLower uplo, uint64 m, uint64 n, 1242 double alpha, const DeviceMemory<double> &a, int lda, 1243 const DeviceMemory<double> &b, int ldb, double beta, 1244 DeviceMemory<double> *c, int ldc) = 0; 1245 virtual bool DoBlasSymm(Stream *stream, blas::Side side, 1246 blas::UpperLower uplo, uint64 m, uint64 n, 1247 std::complex<float> alpha, 1248 const DeviceMemory<std::complex<float>> &a, int lda, 1249 const DeviceMemory<std::complex<float>> &b, int ldb, 1250 std::complex<float> beta, 1251 DeviceMemory<std::complex<float>> *c, int ldc) = 0; 1252 virtual bool DoBlasSymm(Stream *stream, blas::Side side, 1253 blas::UpperLower uplo, uint64 m, uint64 n, 1254 std::complex<double> alpha, 1255 const DeviceMemory<std::complex<double>> &a, int lda, 1256 const DeviceMemory<std::complex<double>> &b, int ldb, 1257 std::complex<double> beta, 1258 DeviceMemory<std::complex<double>> *c, int ldc) = 0; 1259 1260 // Performs a symmetric rank-k update. 1261 // 1262 // c <- alpha * a * a' + beta * c, 1263 // or 1264 // c <- alpha * a' * a + beta * c, 1265 // 1266 // alpha and beta are scalars; c is a n-by-n symmetric matrix; a is an n-by-k 1267 // matrix in the first case and a k-by-n matrix in the second case. 
1268 virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 1269 blas::Transpose trans, uint64 n, uint64 k, 1270 float alpha, const DeviceMemory<float> &a, int lda, 1271 float beta, DeviceMemory<float> *c, int ldc) = 0; 1272 virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 1273 blas::Transpose trans, uint64 n, uint64 k, 1274 double alpha, const DeviceMemory<double> &a, int lda, 1275 double beta, DeviceMemory<double> *c, int ldc) = 0; 1276 virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 1277 blas::Transpose trans, uint64 n, uint64 k, 1278 std::complex<float> alpha, 1279 const DeviceMemory<std::complex<float>> &a, int lda, 1280 std::complex<float> beta, 1281 DeviceMemory<std::complex<float>> *c, int ldc) = 0; 1282 virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 1283 blas::Transpose trans, uint64 n, uint64 k, 1284 std::complex<double> alpha, 1285 const DeviceMemory<std::complex<double>> &a, int lda, 1286 std::complex<double> beta, 1287 DeviceMemory<std::complex<double>> *c, int ldc) = 0; 1288 1289 // Performs a symmetric rank-2k update. 1290 // 1291 // c <- alpha * a * b' + alpha * b * a' + beta * c, 1292 // or 1293 // c <- alpha * b' * a + alpha * a' * b + beta * c, 1294 // 1295 // alpha and beta are scalars; c is a n-by-n symmetric matrix; a and b are 1296 // n-by-k matrices in the first case and k-by-n matrices in the second case. 
1297 virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 1298 blas::Transpose trans, uint64 n, uint64 k, 1299 float alpha, const DeviceMemory<float> &a, int lda, 1300 const DeviceMemory<float> &b, int ldb, float beta, 1301 DeviceMemory<float> *c, int ldc) = 0; 1302 virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 1303 blas::Transpose trans, uint64 n, uint64 k, 1304 double alpha, const DeviceMemory<double> &a, int lda, 1305 const DeviceMemory<double> &b, int ldb, double beta, 1306 DeviceMemory<double> *c, int ldc) = 0; 1307 virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 1308 blas::Transpose trans, uint64 n, uint64 k, 1309 std::complex<float> alpha, 1310 const DeviceMemory<std::complex<float>> &a, int lda, 1311 const DeviceMemory<std::complex<float>> &b, int ldb, 1312 std::complex<float> beta, 1313 DeviceMemory<std::complex<float>> *c, int ldc) = 0; 1314 virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 1315 blas::Transpose trans, uint64 n, uint64 k, 1316 std::complex<double> alpha, 1317 const DeviceMemory<std::complex<double>> &a, int lda, 1318 const DeviceMemory<std::complex<double>> &b, int ldb, 1319 std::complex<double> beta, 1320 DeviceMemory<std::complex<double>> *c, int ldc) = 0; 1321 1322 // Computes a matrix-matrix product where one input matrix is triangular. 1323 // 1324 // b <- alpha * op(a) * b, 1325 // or 1326 // b <- alpha * b * op(a) 1327 // 1328 // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper 1329 // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or 1330 // op(a) = conj(a'). 
1331 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1332 blas::UpperLower uplo, blas::Transpose transa, 1333 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 1334 const DeviceMemory<float> &a, int lda, 1335 DeviceMemory<float> *b, int ldb) = 0; 1336 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1337 blas::UpperLower uplo, blas::Transpose transa, 1338 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 1339 const DeviceMemory<double> &a, int lda, 1340 DeviceMemory<double> *b, int ldb) = 0; 1341 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1342 blas::UpperLower uplo, blas::Transpose transa, 1343 blas::Diagonal diag, uint64 m, uint64 n, 1344 std::complex<float> alpha, 1345 const DeviceMemory<std::complex<float>> &a, int lda, 1346 DeviceMemory<std::complex<float>> *b, int ldb) = 0; 1347 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1348 blas::UpperLower uplo, blas::Transpose transa, 1349 blas::Diagonal diag, uint64 m, uint64 n, 1350 std::complex<double> alpha, 1351 const DeviceMemory<std::complex<double>> &a, int lda, 1352 DeviceMemory<std::complex<double>> *b, int ldb) = 0; 1353 1354 // Solves a triangular matrix equation. 1355 // 1356 // op(a) * x = alpha * b, 1357 // or 1358 // x * op(a) = alpha * b 1359 // 1360 // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit, 1361 // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', 1362 // or op(a) = conj(a'). 
1363 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1364 blas::UpperLower uplo, blas::Transpose transa, 1365 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 1366 const DeviceMemory<float> &a, int lda, 1367 DeviceMemory<float> *b, int ldb) = 0; 1368 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1369 blas::UpperLower uplo, blas::Transpose transa, 1370 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 1371 const DeviceMemory<double> &a, int lda, 1372 DeviceMemory<double> *b, int ldb) = 0; 1373 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1374 blas::UpperLower uplo, blas::Transpose transa, 1375 blas::Diagonal diag, uint64 m, uint64 n, 1376 std::complex<float> alpha, 1377 const DeviceMemory<std::complex<float>> &a, int lda, 1378 DeviceMemory<std::complex<float>> *b, int ldb) = 0; 1379 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1380 blas::UpperLower uplo, blas::Transpose transa, 1381 blas::Diagonal diag, uint64 m, uint64 n, 1382 std::complex<double> alpha, 1383 const DeviceMemory<std::complex<double>> &a, int lda, 1384 DeviceMemory<std::complex<double>> *b, int ldb) = 0; 1385 1386 virtual port::Status GetVersion(string *version) = 0; 1387 1388 protected: BlasSupport()1389 BlasSupport() {} 1390 1391 private: 1392 SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport); 1393 }; 1394 1395 // Macro used to quickly declare overrides for abstract virtuals in the 1396 // BlasSupport base class. 
1397 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ 1398 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1399 const DeviceMemory<float> &x, int incx, \ 1400 DeviceMemory<float> *result) override; \ 1401 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1402 const DeviceMemory<double> &x, int incx, \ 1403 DeviceMemory<double> *result) override; \ 1404 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1405 const DeviceMemory<std::complex<float>> &x, int incx, \ 1406 DeviceMemory<float> *result) override; \ 1407 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1408 const DeviceMemory<std::complex<double>> &x, int incx, \ 1409 DeviceMemory<double> *result) override; \ 1410 bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, \ 1411 const DeviceMemory<float> &x, int incx, \ 1412 DeviceMemory<float> *y, int incy) override; \ 1413 bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, \ 1414 const DeviceMemory<double> &x, int incx, \ 1415 DeviceMemory<double> *y, int incy) override; \ 1416 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1417 std::complex<float> alpha, \ 1418 const DeviceMemory<std::complex<float>> &x, int incx, \ 1419 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1420 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1421 std::complex<double> alpha, \ 1422 const DeviceMemory<std::complex<double>> &x, int incx, \ 1423 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1424 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1425 const DeviceMemory<float> &x, int incx, \ 1426 DeviceMemory<float> *y, int incy) override; \ 1427 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1428 const DeviceMemory<double> &x, int incx, \ 1429 DeviceMemory<double> *y, int incy) override; \ 1430 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1431 const DeviceMemory<std::complex<float>> &x, int incx, \ 1432 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1433 bool 
DoBlasCopy(Stream *stream, uint64 elem_count, \ 1434 const DeviceMemory<std::complex<double>> &x, int incx, \ 1435 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1436 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1437 const DeviceMemory<float> &x, int incx, \ 1438 const DeviceMemory<float> &y, int incy, \ 1439 DeviceMemory<float> *result) override; \ 1440 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1441 const DeviceMemory<double> &x, int incx, \ 1442 const DeviceMemory<double> &y, int incy, \ 1443 DeviceMemory<double> *result) override; \ 1444 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1445 const DeviceMemory<std::complex<float>> &x, int incx, \ 1446 const DeviceMemory<std::complex<float>> &y, int incy, \ 1447 DeviceMemory<std::complex<float>> *result) override; \ 1448 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1449 const DeviceMemory<std::complex<double>> &x, int incx, \ 1450 const DeviceMemory<std::complex<double>> &y, int incy, \ 1451 DeviceMemory<std::complex<double>> *result) override; \ 1452 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1453 const DeviceMemory<std::complex<float>> &x, int incx, \ 1454 const DeviceMemory<std::complex<float>> &y, int incy, \ 1455 DeviceMemory<std::complex<float>> *result) override; \ 1456 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1457 const DeviceMemory<std::complex<double>> &x, int incx, \ 1458 const DeviceMemory<std::complex<double>> &y, int incy, \ 1459 DeviceMemory<std::complex<double>> *result) override; \ 1460 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1461 const DeviceMemory<float> &x, int incx, \ 1462 DeviceMemory<float> *result) override; \ 1463 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1464 const DeviceMemory<double> &x, int incx, \ 1465 DeviceMemory<double> *result) override; \ 1466 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1467 const DeviceMemory<std::complex<float>> &x, int incx, \ 1468 DeviceMemory<float> *result) 
override; \ 1469 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1470 const DeviceMemory<std::complex<double>> &x, int incx, \ 1471 DeviceMemory<double> *result) override; \ 1472 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1473 int incx, DeviceMemory<float> *y, int incy, float c, float s) \ 1474 override; \ 1475 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1476 int incx, DeviceMemory<double> *y, int incy, double c, \ 1477 double s) override; \ 1478 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1479 DeviceMemory<std::complex<float>> *x, int incx, \ 1480 DeviceMemory<std::complex<float>> *y, int incy, float c, \ 1481 float s) override; \ 1482 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1483 DeviceMemory<std::complex<double>> *x, int incx, \ 1484 DeviceMemory<std::complex<double>> *y, int incy, double c, \ 1485 double s) override; \ 1486 bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, \ 1487 DeviceMemory<float> *b, DeviceMemory<float> *c, \ 1488 DeviceMemory<float> *s) override; \ 1489 bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, \ 1490 DeviceMemory<double> *b, DeviceMemory<double> *c, \ 1491 DeviceMemory<double> *s) override; \ 1492 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, \ 1493 DeviceMemory<std::complex<float>> *b, \ 1494 DeviceMemory<float> *c, \ 1495 DeviceMemory<std::complex<float>> *s) override; \ 1496 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, \ 1497 DeviceMemory<std::complex<double>> *b, \ 1498 DeviceMemory<double> *c, \ 1499 DeviceMemory<std::complex<double>> *s) override; \ 1500 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1501 int incx, DeviceMemory<float> *y, int incy, \ 1502 const DeviceMemory<float> ¶m) override; \ 1503 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1504 int incx, DeviceMemory<double> *y, int incy, \ 1505 const 
DeviceMemory<double> ¶m) override; \ 1506 bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, \ 1507 DeviceMemory<float> *d2, DeviceMemory<float> *x1, \ 1508 const DeviceMemory<float> &y1, DeviceMemory<float> *param) \ 1509 override; \ 1510 bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, \ 1511 DeviceMemory<double> *d2, DeviceMemory<double> *x1, \ 1512 const DeviceMemory<double> &y1, \ 1513 DeviceMemory<double> *param) override; \ 1514 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1515 DeviceMemory<float> *x, int incx) override; \ 1516 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1517 DeviceMemory<double> *x, int incx) override; \ 1518 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1519 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1520 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1521 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1522 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1523 std::complex<float> alpha, \ 1524 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1525 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1526 std::complex<double> alpha, \ 1527 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1528 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1529 int incx, DeviceMemory<float> *y, int incy) override; \ 1530 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1531 int incx, DeviceMemory<double> *y, int incy) override; \ 1532 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1533 DeviceMemory<std::complex<float>> *x, int incx, \ 1534 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1535 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1536 DeviceMemory<std::complex<double>> *x, int incx, \ 1537 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1538 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1539 
const DeviceMemory<float> &x, int incx, \ 1540 DeviceMemory<int> *result) override; \ 1541 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1542 const DeviceMemory<double> &x, int incx, \ 1543 DeviceMemory<int> *result) override; \ 1544 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1545 const DeviceMemory<std::complex<float>> &x, int incx, \ 1546 DeviceMemory<int> *result) override; \ 1547 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1548 const DeviceMemory<std::complex<double>> &x, int incx, \ 1549 DeviceMemory<int> *result) override; \ 1550 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1551 const DeviceMemory<float> &x, int incx, \ 1552 DeviceMemory<int> *result) override; \ 1553 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1554 const DeviceMemory<double> &x, int incx, \ 1555 DeviceMemory<int> *result) override; \ 1556 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1557 const DeviceMemory<std::complex<float>> &x, int incx, \ 1558 DeviceMemory<int> *result) override; \ 1559 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1560 const DeviceMemory<std::complex<double>> &x, int incx, \ 1561 DeviceMemory<int> *result) override; \ 1562 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1563 uint64 kl, uint64 ku, float alpha, \ 1564 const DeviceMemory<float> &a, int lda, \ 1565 const DeviceMemory<float> &x, int incx, float beta, \ 1566 DeviceMemory<float> *y, int incy) override; \ 1567 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1568 uint64 kl, uint64 ku, double alpha, \ 1569 const DeviceMemory<double> &a, int lda, \ 1570 const DeviceMemory<double> &x, int incx, double beta, \ 1571 DeviceMemory<double> *y, int incy) override; \ 1572 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1573 uint64 kl, uint64 ku, std::complex<float> alpha, \ 1574 const DeviceMemory<std::complex<float>> &a, int lda, \ 1575 const 
DeviceMemory<std::complex<float>> &x, int incx, \ 1576 std::complex<float> beta, \ 1577 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1578 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1579 uint64 kl, uint64 ku, std::complex<double> alpha, \ 1580 const DeviceMemory<std::complex<double>> &a, int lda, \ 1581 const DeviceMemory<std::complex<double>> &x, int incx, \ 1582 std::complex<double> beta, \ 1583 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1584 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1585 float alpha, const DeviceMemory<float> &a, int lda, \ 1586 const DeviceMemory<float> &x, int incx, float beta, \ 1587 DeviceMemory<float> *y, int incy) override; \ 1588 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1589 double alpha, const DeviceMemory<double> &a, int lda, \ 1590 const DeviceMemory<double> &x, int incx, double beta, \ 1591 DeviceMemory<double> *y, int incy) override; \ 1592 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1593 std::complex<float> alpha, \ 1594 const DeviceMemory<std::complex<float>> &a, int lda, \ 1595 const DeviceMemory<std::complex<float>> &x, int incx, \ 1596 std::complex<float> beta, \ 1597 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1598 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1599 std::complex<double> alpha, \ 1600 const DeviceMemory<std::complex<double>> &a, int lda, \ 1601 const DeviceMemory<std::complex<double>> &x, int incx, \ 1602 std::complex<double> beta, \ 1603 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1604 bool DoBlasGemvWithProfiling( \ 1605 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \ 1606 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \ 1607 int incx, float beta, DeviceMemory<float> *y, int incy, \ 1608 blas::ProfileResult *output_profile_result) 
override; \ 1609 bool DoBlasGemvWithProfiling( \ 1610 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \ 1611 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \ 1612 int incx, double beta, DeviceMemory<double> *y, int incy, \ 1613 blas::ProfileResult *output_profile_result) override; \ 1614 bool DoBlasGemvWithProfiling( \ 1615 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1616 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \ 1617 int lda, const DeviceMemory<std::complex<float>> &x, int incx, \ 1618 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \ 1619 int incy, blas::ProfileResult *output_profile_result) override; \ 1620 bool DoBlasGemvWithProfiling( \ 1621 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1622 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \ 1623 int lda, const DeviceMemory<std::complex<double>> &x, int incx, \ 1624 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \ 1625 int incy, blas::ProfileResult *output_profile_result) override; \ 1626 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \ 1627 const DeviceMemory<float> &x, int incx, \ 1628 const DeviceMemory<float> &y, int incy, \ 1629 DeviceMemory<float> *a, int lda) override; \ 1630 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, \ 1631 const DeviceMemory<double> &x, int incx, \ 1632 const DeviceMemory<double> &y, int incy, \ 1633 DeviceMemory<double> *a, int lda) override; \ 1634 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1635 std::complex<float> alpha, \ 1636 const DeviceMemory<std::complex<float>> &x, int incx, \ 1637 const DeviceMemory<std::complex<float>> &y, int incy, \ 1638 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1639 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1640 std::complex<double> alpha, \ 1641 const DeviceMemory<std::complex<double>> &x, int incx, \ 1642 
const DeviceMemory<std::complex<double>> &y, int incy, \ 1643 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1644 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1645 std::complex<float> alpha, \ 1646 const DeviceMemory<std::complex<float>> &x, int incx, \ 1647 const DeviceMemory<std::complex<float>> &y, int incy, \ 1648 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1649 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1650 std::complex<double> alpha, \ 1651 const DeviceMemory<std::complex<double>> &x, int incx, \ 1652 const DeviceMemory<std::complex<double>> &y, int incy, \ 1653 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1654 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1655 std::complex<float> alpha, \ 1656 const DeviceMemory<std::complex<float>> &a, int lda, \ 1657 const DeviceMemory<std::complex<float>> &x, int incx, \ 1658 std::complex<float> beta, \ 1659 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1660 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1661 std::complex<double> alpha, \ 1662 const DeviceMemory<std::complex<double>> &a, int lda, \ 1663 const DeviceMemory<std::complex<double>> &x, int incx, \ 1664 std::complex<double> beta, \ 1665 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1666 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1667 std::complex<float> alpha, \ 1668 const DeviceMemory<std::complex<float>> &a, int lda, \ 1669 const DeviceMemory<std::complex<float>> &x, int incx, \ 1670 std::complex<float> beta, \ 1671 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1672 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1673 std::complex<double> alpha, \ 1674 const DeviceMemory<std::complex<double>> &a, int lda, \ 1675 const DeviceMemory<std::complex<double>> &x, int incx, \ 1676 std::complex<double> beta, \ 1677 DeviceMemory<std::complex<double>> *y, int 
incy) override; \ 1678 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1679 const DeviceMemory<std::complex<float>> &x, int incx, \ 1680 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1681 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1682 double alpha, const DeviceMemory<std::complex<double>> &x, \ 1683 int incx, DeviceMemory<std::complex<double>> *a, int lda) \ 1684 override; \ 1685 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1686 std::complex<float> alpha, \ 1687 const DeviceMemory<std::complex<float>> &x, int incx, \ 1688 const DeviceMemory<std::complex<float>> &y, int incy, \ 1689 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1690 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1691 std::complex<double> alpha, \ 1692 const DeviceMemory<std::complex<double>> &x, int incx, \ 1693 const DeviceMemory<std::complex<double>> &y, int incy, \ 1694 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1695 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1696 std::complex<float> alpha, \ 1697 const DeviceMemory<std::complex<float>> &ap, \ 1698 const DeviceMemory<std::complex<float>> &x, int incx, \ 1699 std::complex<float> beta, \ 1700 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1701 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1702 std::complex<double> alpha, \ 1703 const DeviceMemory<std::complex<double>> &ap, \ 1704 const DeviceMemory<std::complex<double>> &x, int incx, \ 1705 std::complex<double> beta, \ 1706 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1707 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1708 const DeviceMemory<std::complex<float>> &x, int incx, \ 1709 DeviceMemory<std::complex<float>> *ap) override; \ 1710 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1711 double alpha, const 
DeviceMemory<std::complex<double>> &x, \ 1712 int incx, DeviceMemory<std::complex<double>> *ap) override; \ 1713 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1714 std::complex<float> alpha, \ 1715 const DeviceMemory<std::complex<float>> &x, int incx, \ 1716 const DeviceMemory<std::complex<float>> &y, int incy, \ 1717 DeviceMemory<std::complex<float>> *ap) override; \ 1718 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1719 std::complex<double> alpha, \ 1720 const DeviceMemory<std::complex<double>> &x, int incx, \ 1721 const DeviceMemory<std::complex<double>> &y, int incy, \ 1722 DeviceMemory<std::complex<double>> *ap) override; \ 1723 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1724 float alpha, const DeviceMemory<float> &a, int lda, \ 1725 const DeviceMemory<float> &x, int incx, float beta, \ 1726 DeviceMemory<float> *y, int incy) override; \ 1727 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1728 double alpha, const DeviceMemory<double> &a, int lda, \ 1729 const DeviceMemory<double> &x, int incx, double beta, \ 1730 DeviceMemory<double> *y, int incy) override; \ 1731 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1732 float alpha, const DeviceMemory<float> &ap, \ 1733 const DeviceMemory<float> &x, int incx, float beta, \ 1734 DeviceMemory<float> *y, int incy) override; \ 1735 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1736 double alpha, const DeviceMemory<double> &ap, \ 1737 const DeviceMemory<double> &x, int incx, double beta, \ 1738 DeviceMemory<double> *y, int incy) override; \ 1739 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1740 const DeviceMemory<float> &x, int incx, \ 1741 DeviceMemory<float> *ap) override; \ 1742 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1743 double alpha, const DeviceMemory<double> &x, int incx, \ 1744 DeviceMemory<double> *ap) 
override; \ 1745 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1746 float alpha, const DeviceMemory<float> &x, int incx, \ 1747 const DeviceMemory<float> &y, int incy, \ 1748 DeviceMemory<float> *ap) override; \ 1749 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1750 double alpha, const DeviceMemory<double> &x, int incx, \ 1751 const DeviceMemory<double> &y, int incy, \ 1752 DeviceMemory<double> *ap) override; \ 1753 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1754 float alpha, const DeviceMemory<float> &a, int lda, \ 1755 const DeviceMemory<float> &x, int incx, float beta, \ 1756 DeviceMemory<float> *y, int incy) override; \ 1757 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1758 double alpha, const DeviceMemory<double> &a, int lda, \ 1759 const DeviceMemory<double> &x, int incx, double beta, \ 1760 DeviceMemory<double> *y, int incy) override; \ 1761 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1762 const DeviceMemory<float> &x, int incx, \ 1763 DeviceMemory<float> *a, int lda) override; \ 1764 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1765 double alpha, const DeviceMemory<double> &x, int incx, \ 1766 DeviceMemory<double> *a, int lda) override; \ 1767 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1768 float alpha, const DeviceMemory<float> &x, int incx, \ 1769 const DeviceMemory<float> &y, int incy, \ 1770 DeviceMemory<float> *a, int lda) override; \ 1771 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1772 double alpha, const DeviceMemory<double> &x, int incx, \ 1773 const DeviceMemory<double> &y, int incy, \ 1774 DeviceMemory<double> *a, int lda) override; \ 1775 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1776 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1777 uint64 k, const DeviceMemory<float> &a, int lda, \ 1778 DeviceMemory<float> *x, int incx) override; 
\ 1779 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1780 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1781 uint64 k, const DeviceMemory<double> &a, int lda, \ 1782 DeviceMemory<double> *x, int incx) override; \ 1783 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1784 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1785 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1786 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1787 override; \ 1788 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1789 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1790 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1791 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1792 override; \ 1793 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1794 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1795 uint64 k, const DeviceMemory<float> &a, int lda, \ 1796 DeviceMemory<float> *x, int incx) override; \ 1797 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1798 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1799 uint64 k, const DeviceMemory<double> &a, int lda, \ 1800 DeviceMemory<double> *x, int incx) override; \ 1801 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1802 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1803 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1804 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1805 override; \ 1806 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1807 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1808 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1809 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1810 override; \ 1811 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1812 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1813 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1814 int incx) override; \ 1815 bool DoBlasTpmv(Stream 
*stream, blas::UpperLower uplo, \ 1816 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1817 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1818 int incx) override; \ 1819 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1820 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1821 const DeviceMemory<std::complex<float>> &ap, \ 1822 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1823 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1824 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1825 const DeviceMemory<std::complex<double>> &ap, \ 1826 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1827 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1828 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1829 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1830 int incx) override; \ 1831 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1832 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1833 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1834 int incx) override; \ 1835 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1836 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1837 const DeviceMemory<std::complex<float>> &ap, \ 1838 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1839 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1840 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1841 const DeviceMemory<std::complex<double>> &ap, \ 1842 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1843 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1844 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1845 const DeviceMemory<float> &a, int lda, \ 1846 DeviceMemory<float> *x, int incx) override; \ 1847 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1848 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1849 const DeviceMemory<double> &a, int lda, \ 1850 DeviceMemory<double> *x, int 
incx) override; \ 1851 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1852 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1853 const DeviceMemory<std::complex<float>> &a, int lda, \ 1854 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1855 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1856 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1857 const DeviceMemory<std::complex<double>> &a, int lda, \ 1858 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1859 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1860 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1861 const DeviceMemory<float> &a, int lda, \ 1862 DeviceMemory<float> *x, int incx) override; \ 1863 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1864 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1865 const DeviceMemory<double> &a, int lda, \ 1866 DeviceMemory<double> *x, int incx) override; \ 1867 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1868 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1869 const DeviceMemory<std::complex<float>> &a, int lda, \ 1870 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1871 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1872 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1873 const DeviceMemory<std::complex<double>> &a, int lda, \ 1874 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1875 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1876 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1877 float alpha, const DeviceMemory<Eigen::half> &a, int lda, \ 1878 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 1879 DeviceMemory<Eigen::half> *c, int ldc) override; \ 1880 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1881 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1882 float alpha, const DeviceMemory<float> &a, int lda, \ 1883 const DeviceMemory<float> &b, int ldb, float 
beta, \ 1884 DeviceMemory<float> *c, int ldc) override; \ 1885 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1886 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1887 double alpha, const DeviceMemory<double> &a, int lda, \ 1888 const DeviceMemory<double> &b, int ldb, double beta, \ 1889 DeviceMemory<double> *c, int ldc) override; \ 1890 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1891 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1892 std::complex<float> alpha, \ 1893 const DeviceMemory<std::complex<float>> &a, int lda, \ 1894 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1895 std::complex<float> beta, \ 1896 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 1897 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1898 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1899 std::complex<double> alpha, \ 1900 const DeviceMemory<std::complex<double>> &a, int lda, \ 1901 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1902 std::complex<double> beta, \ 1903 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 1904 bool DoBlasGemmWithProfiling( \ 1905 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1906 uint64 m, uint64 n, uint64 k, float alpha, \ 1907 const DeviceMemory<Eigen::half> &a, int lda, \ 1908 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 1909 DeviceMemory<Eigen::half> *c, int ldc, \ 1910 blas::ProfileResult *output_profile_result) override; \ 1911 bool DoBlasGemmWithProfiling( \ 1912 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1913 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 1914 int lda, const DeviceMemory<float> &b, int ldb, float beta, \ 1915 DeviceMemory<float> *c, int ldc, \ 1916 blas::ProfileResult *output_profile_result) override; \ 1917 bool DoBlasGemmWithProfiling( \ 1918 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1919 uint64 m, uint64 n, uint64 k, double 
alpha, \ 1920 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 1921 int ldb, double beta, DeviceMemory<double> *c, int ldc, \ 1922 blas::ProfileResult *output_profile_result) override; \ 1923 bool DoBlasGemmWithProfiling( \ 1924 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1925 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 1926 const DeviceMemory<std::complex<float>> &a, int lda, \ 1927 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1928 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 1929 blas::ProfileResult *output_profile_result) override; \ 1930 bool DoBlasGemmWithProfiling( \ 1931 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1932 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 1933 const DeviceMemory<std::complex<double>> &a, int lda, \ 1934 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1935 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 1936 int ldc, blas::ProfileResult *output_profile_result) override; \ 1937 bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \ 1938 override; \ 1939 bool DoBlasGemmWithAlgorithm( \ 1940 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1941 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha, \ 1942 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, \ 1943 int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, \ 1944 int ldc, blas::ComputationType computation_type, \ 1945 blas::AlgorithmType algorithm, \ 1946 blas::ProfileResult *output_profile_result) override; \ 1947 bool DoBlasGemmWithAlgorithm( \ 1948 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1949 uint64 m, uint64 n, uint64 k, \ 1950 const HostOrDeviceScalar<Eigen::half> &alpha, \ 1951 const DeviceMemory<Eigen::half> &a, int lda, \ 1952 const DeviceMemory<Eigen::half> &b, int ldb, \ 1953 const 
HostOrDeviceScalar<Eigen::half> &beta, \ 1954 DeviceMemory<Eigen::half> *c, int ldc, \ 1955 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1956 blas::ProfileResult *output_profile_result) override; \ 1957 bool DoBlasGemmWithAlgorithm( \ 1958 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1959 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha, \ 1960 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, \ 1961 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, \ 1962 int ldc, blas::ComputationType computation_type, \ 1963 blas::AlgorithmType algorithm, \ 1964 blas::ProfileResult *output_profile_result) override; \ 1965 bool DoBlasGemmWithAlgorithm( \ 1966 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1967 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha, \ 1968 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 1969 int ldb, const HostOrDeviceScalar<double> &beta, \ 1970 DeviceMemory<double> *c, int ldc, \ 1971 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1972 blas::ProfileResult *output_profile_result) override; \ 1973 bool DoBlasGemmWithAlgorithm( \ 1974 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1975 uint64 m, uint64 n, uint64 k, \ 1976 const HostOrDeviceScalar<std::complex<float>> &alpha, \ 1977 const DeviceMemory<std::complex<float>> &a, int lda, \ 1978 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1979 const HostOrDeviceScalar<std::complex<float>> &beta, \ 1980 DeviceMemory<std::complex<float>> *c, int ldc, \ 1981 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1982 blas::ProfileResult *output_profile_result) override; \ 1983 bool DoBlasGemmWithAlgorithm( \ 1984 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1985 uint64 m, uint64 n, uint64 k, \ 1986 const HostOrDeviceScalar<std::complex<double>> 
&alpha, \ 1987 const DeviceMemory<std::complex<double>> &a, int lda, \ 1988 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1989 const HostOrDeviceScalar<std::complex<double>> &beta, \ 1990 DeviceMemory<std::complex<double>> *c, int ldc, \ 1991 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1992 blas::ProfileResult *output_profile_result) override; \ 1993 bool DoBlasGemmBatched( \ 1994 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1995 uint64 m, uint64 n, uint64 k, float alpha, \ 1996 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \ 1997 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \ 1998 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \ 1999 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2000 bool DoBlasGemmBatched( \ 2001 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2002 uint64 m, uint64 n, uint64 k, float alpha, \ 2003 const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \ 2004 const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \ 2005 const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \ 2006 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2007 bool DoBlasGemmBatched( \ 2008 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2009 uint64 m, uint64 n, uint64 k, double alpha, \ 2010 const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \ 2011 const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \ 2012 const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \ 2013 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2014 bool DoBlasGemmBatched( \ 2015 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2016 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2017 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \ 2018 const 
port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \ 2019 std::complex<float> beta, \ 2020 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \ 2021 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2022 bool DoBlasGemmBatched( \ 2023 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2024 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2025 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, \ 2026 int lda, \ 2027 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \ 2028 int ldb, std::complex<double> beta, \ 2029 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ 2030 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2031 bool DoBlasGemmStridedBatched( \ 2032 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2033 uint64 m, uint64 n, uint64 k, float alpha, \ 2034 const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ 2035 const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ 2036 DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ 2037 bool DoBlasGemmStridedBatched( \ 2038 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2039 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 2040 int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ 2041 int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ 2042 int64 stride_c, int batch_count); \ 2043 bool DoBlasGemmStridedBatched( \ 2044 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2045 uint64 m, uint64 n, uint64 k, double alpha, \ 2046 const DeviceMemory<double> &a, int lda, int64 stride_a, \ 2047 const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ 2048 DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ 2049 bool DoBlasGemmStridedBatched( \ 2050 Stream *stream, blas::Transpose transa, 
blas::Transpose transb, \ 2051 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2052 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ 2053 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ 2054 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 2055 int64 stride_c, int batch_count); \ 2056 bool DoBlasGemmStridedBatched( \ 2057 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2058 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2059 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \ 2060 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ 2061 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 2062 int ldc, int64 stride_c, int batch_count); \ 2063 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2064 uint64 m, uint64 n, std::complex<float> alpha, \ 2065 const DeviceMemory<std::complex<float>> &a, int lda, \ 2066 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2067 std::complex<float> beta, \ 2068 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2069 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2070 uint64 m, uint64 n, std::complex<double> alpha, \ 2071 const DeviceMemory<std::complex<double>> &a, int lda, \ 2072 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2073 std::complex<double> beta, \ 2074 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2075 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2076 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2077 const DeviceMemory<std::complex<float>> &a, int lda, \ 2078 float beta, DeviceMemory<std::complex<float>> *c, int ldc) \ 2079 override; \ 2080 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2081 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2082 const DeviceMemory<std::complex<double>> &a, int lda, \ 2083 double beta, 
DeviceMemory<std::complex<double>> *c, int ldc) \ 2084 override; \ 2085 bool DoBlasHer2k( \ 2086 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2087 uint64 k, std::complex<float> alpha, \ 2088 const DeviceMemory<std::complex<float>> &a, int lda, \ 2089 const DeviceMemory<std::complex<float>> &b, int ldb, float beta, \ 2090 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2091 bool DoBlasHer2k( \ 2092 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2093 uint64 k, std::complex<double> alpha, \ 2094 const DeviceMemory<std::complex<double>> &a, int lda, \ 2095 const DeviceMemory<std::complex<double>> &b, int ldb, double beta, \ 2096 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2097 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2098 uint64 m, uint64 n, float alpha, \ 2099 const DeviceMemory<float> &a, int lda, \ 2100 const DeviceMemory<float> &b, int ldb, float beta, \ 2101 DeviceMemory<float> *c, int ldc) override; \ 2102 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2103 uint64 m, uint64 n, double alpha, \ 2104 const DeviceMemory<double> &a, int lda, \ 2105 const DeviceMemory<double> &b, int ldb, double beta, \ 2106 DeviceMemory<double> *c, int ldc) override; \ 2107 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2108 uint64 m, uint64 n, std::complex<float> alpha, \ 2109 const DeviceMemory<std::complex<float>> &a, int lda, \ 2110 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2111 std::complex<float> beta, \ 2112 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2113 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2114 uint64 m, uint64 n, std::complex<double> alpha, \ 2115 const DeviceMemory<std::complex<double>> &a, int lda, \ 2116 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2117 std::complex<double> beta, \ 2118 DeviceMemory<std::complex<double>> *c, int 
ldc) override; \ 2119 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2120 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2121 const DeviceMemory<float> &a, int lda, float beta, \ 2122 DeviceMemory<float> *c, int ldc) override; \ 2123 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2124 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2125 const DeviceMemory<double> &a, int lda, double beta, \ 2126 DeviceMemory<double> *c, int ldc) override; \ 2127 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2128 blas::Transpose trans, uint64 n, uint64 k, \ 2129 std::complex<float> alpha, \ 2130 const DeviceMemory<std::complex<float>> &a, int lda, \ 2131 std::complex<float> beta, \ 2132 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2133 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2134 blas::Transpose trans, uint64 n, uint64 k, \ 2135 std::complex<double> alpha, \ 2136 const DeviceMemory<std::complex<double>> &a, int lda, \ 2137 std::complex<double> beta, \ 2138 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2139 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2140 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2141 const DeviceMemory<float> &a, int lda, \ 2142 const DeviceMemory<float> &b, int ldb, float beta, \ 2143 DeviceMemory<float> *c, int ldc) override; \ 2144 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2145 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2146 const DeviceMemory<double> &a, int lda, \ 2147 const DeviceMemory<double> &b, int ldb, double beta, \ 2148 DeviceMemory<double> *c, int ldc) override; \ 2149 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2150 blas::Transpose trans, uint64 n, uint64 k, \ 2151 std::complex<float> alpha, \ 2152 const DeviceMemory<std::complex<float>> &a, int lda, \ 2153 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2154 std::complex<float> beta, \ 2155 
DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2156 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2157 blas::Transpose trans, uint64 n, uint64 k, \ 2158 std::complex<double> alpha, \ 2159 const DeviceMemory<std::complex<double>> &a, int lda, \ 2160 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2161 std::complex<double> beta, \ 2162 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2163 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2164 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2165 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2166 int lda, DeviceMemory<float> *b, int ldb) override; \ 2167 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2168 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2169 uint64 n, double alpha, const DeviceMemory<double> &a, \ 2170 int lda, DeviceMemory<double> *b, int ldb) override; \ 2171 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2172 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2173 uint64 n, std::complex<float> alpha, \ 2174 const DeviceMemory<std::complex<float>> &a, int lda, \ 2175 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2176 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2177 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2178 uint64 n, std::complex<double> alpha, \ 2179 const DeviceMemory<std::complex<double>> &a, int lda, \ 2180 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2181 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2182 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2183 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2184 int lda, DeviceMemory<float> *b, int ldb) override; \ 2185 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2186 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2187 uint64 n, double 
alpha, const DeviceMemory<double> &a, \ 2188 int lda, DeviceMemory<double> *b, int ldb) override; \ 2189 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2190 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2191 uint64 n, std::complex<float> alpha, \ 2192 const DeviceMemory<std::complex<float>> &a, int lda, \ 2193 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2194 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2195 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2196 uint64 n, std::complex<double> alpha, \ 2197 const DeviceMemory<std::complex<double>> &a, int lda, \ 2198 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2199 port::Status GetVersion(string *version) override; 2200 2201 } // namespace blas 2202 } // namespace stream_executor 2203 2204 #endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 2205