1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Exposes the family of BLAS routines as pre-canned high performance calls for 17 // use in conjunction with the StreamExecutor abstraction. 18 // 19 // Note that this interface is optionally supported by platforms; see 20 // StreamExecutor::SupportsBlas() for details. 21 // 22 // This abstraction makes it simple to entrain BLAS operations on GPU data into 23 // a Stream -- users typically will not use this API directly, but will use the 24 // Stream builder methods to entrain these operations "under the hood". For 25 // example: 26 // 27 // DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024); 28 // DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024); 29 // // ... populate x and y ... 30 // Stream stream{stream_exec}; 31 // stream 32 // .Init() 33 // .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1); 34 // SE_CHECK_OK(stream.BlockHostUntilDone()); 35 // 36 // By using stream operations in this manner the user can easily intermix custom 37 // kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS 38 // routines. 
39 40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 42 43 #include <complex> 44 #include <vector> 45 46 #include "tensorflow/stream_executor/dnn.h" // For DataType, ToDataType 47 #include "tensorflow/stream_executor/lib/array_slice.h" 48 #include "tensorflow/stream_executor/lib/statusor.h" 49 #include "tensorflow/stream_executor/platform/port.h" 50 51 namespace Eigen { 52 struct half; 53 } // namespace Eigen 54 55 namespace stream_executor { 56 57 class Stream; 58 class ScratchAllocator; 59 60 template <typename ElemT> 61 class DeviceMemory; 62 63 template <typename ElemT> 64 class HostOrDeviceScalar; 65 66 namespace blas { 67 68 // Specifies whether the input matrix will be transposed or 69 // transposed+conjugated before any BLAS operations. 70 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; 71 72 // Returns a name for t. 73 std::string TransposeString(Transpose t); 74 75 // Specifies whether the upper or lower triangular part of a 76 // symmetric/Hermitian matrix is used. 77 enum class UpperLower { kUpper, kLower }; 78 79 // Returns a name for ul. 80 std::string UpperLowerString(UpperLower ul); 81 82 // Specifies whether a matrix is unit triangular. 83 enum class Diagonal { kUnit, kNonUnit }; 84 85 // Returns a name for d. 86 std::string DiagonalString(Diagonal d); 87 88 // Specifies whether a Hermitian matrix appears on the left or right in 89 // operation. 90 enum class Side { kLeft, kRight }; 91 92 // Returns a name for s. 93 std::string SideString(Side s); 94 95 // Type with which intermediate computations of a blas routine are performed. 96 // 97 // Some blas calls can perform computations with a type that's different than 98 // the type of their inputs/outputs. This lets you e.g. multiply two matrices 99 // of int8s using float32s to store the matmul's intermediate values. 
100 enum class ComputationType { 101 kF16, // 16-bit floating-point 102 kF32, // 32-bit floating-point 103 kF64, // 64-bit floating-point 104 kI32, // 32-bit integer 105 kComplexF32, // Complex number comprised of two f32s. 106 kComplexF64, // Complex number comprised of two f64s. 107 // The below values are only supported for BlasLt routines (both real and 108 // complex). They use float32 for accumulation but round the input mantissas 109 // to a smaller number of bits. 110 kTF32AsF32, // 32-bit floating-point with reduced (>=10-bit) mantissa 111 kBF16AsF32, // 32-bit floating-point with reduced (7-bit) mantissa 112 }; 113 114 enum class Epilogue { 115 kDefault = 1, // No special postprocessing 116 kReLU = 2, // Apply ReLU func point-wise to the results 117 kBias = 4, // Add broadcasted bias vector to the results 118 kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform 119 }; 120 121 // Converts a ComputationType to a string. 122 std::string ComputationTypeString(ComputationType ty); 123 124 std::ostream &operator<<(std::ostream &os, ComputationType ty); 125 126 using dnn::DataType; 127 using dnn::ToDataType; 128 129 // Describes the type of pointers for the scaling factors alpha and beta in 130 // blaslt routines. 131 enum class PointerMode { 132 kHost, 133 kDevice, 134 }; 135 136 // Converts a ComputationType to a string. 137 std::string DataTypeString(DataType ty); 138 139 std::ostream &operator<<(std::ostream &os, DataType ty); 140 141 // Opaque identifier for an "algorithm" used by a blas routine. This functions 142 // as a hint to the blas library. 143 typedef int64 AlgorithmType; 144 constexpr AlgorithmType kDefaultAlgorithm = -1; 145 constexpr AlgorithmType kDefaultBlasGemm = -2; 146 constexpr AlgorithmType kDefaultBlasGemv = -3; 147 constexpr AlgorithmType kNoAlgorithm = -4; 148 149 // blas uses -1 to represent the default algorithm. 
This happens to match up 150 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast 151 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert 152 // to ensure that this assumption does not break. 153 // If another blas implementation uses a different value for the default 154 // algorithm, then it needs to convert kDefaultGemmAlgo to that value 155 // (e.g. via a function called ToWhateverGemmAlgo). 156 constexpr AlgorithmType kDefaultGemmAlgo = -1; 157 158 // Describes the result of a performance experiment, usually timing the speed of 159 // a particular AlgorithmType. 160 // 161 // If the call we were benchmarking failed (a common occurrence; not all 162 // algorithms are valid for all calls), is_valid() will be false. 163 class ProfileResult { 164 public: is_valid()165 bool is_valid() const { return is_valid_; } set_is_valid(bool val)166 void set_is_valid(bool val) { is_valid_ = val; } algorithm()167 AlgorithmType algorithm() const { return algorithm_; } set_algorithm(AlgorithmType val)168 void set_algorithm(AlgorithmType val) { algorithm_ = val; } elapsed_time_in_ms()169 float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } set_elapsed_time_in_ms(float val)170 void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } 171 172 private: 173 bool is_valid_ = false; 174 AlgorithmType algorithm_ = kDefaultAlgorithm; 175 float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); 176 }; 177 178 class AlgorithmConfig { 179 public: AlgorithmConfig()180 AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} AlgorithmConfig(AlgorithmType algorithm)181 explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} algorithm()182 AlgorithmType algorithm() const { return algorithm_; } set_algorithm(AlgorithmType val)183 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 184 bool operator==(const AlgorithmConfig &other) const { 185 return this->algorithm_ == other.algorithm_; 186 } 
187 bool operator!=(const AlgorithmConfig &other) const { 188 return !(*this == other); 189 } 190 std::string ToString() const; 191 192 private: 193 AlgorithmType algorithm_; 194 }; 195 196 struct IBlasLtMatmulPlan { 197 // Returns the data type of the A and B (input) matrices. 198 virtual DataType ab_type() const = 0; 199 // Returns the data type of the C (input/output) matrix. 200 virtual DataType c_type() const = 0; ~IBlasLtMatmulPlanIBlasLtMatmulPlan201 virtual ~IBlasLtMatmulPlan() {} 202 }; 203 204 struct IBlasLtMatmulAlgorithm { ~IBlasLtMatmulAlgorithmIBlasLtMatmulAlgorithm205 virtual ~IBlasLtMatmulAlgorithm() {} 206 // Returns the index of the algorithm within the list returned by 207 // GetBlasLtMatmulAlgorithms. 208 virtual AlgorithmType index() const = 0; 209 // Returns the workspace size required by the algorithm in bytes. 210 virtual size_t workspace_size() const = 0; 211 }; 212 213 // Parameters for the CreateBlasLtMatmulPlan method. 214 struct BlasLtMatmulPlanParams { 215 DataType ab_type; 216 DataType c_type; 217 ComputationType computation_type; 218 PointerMode pointer_mode; 219 Epilogue epilogue; 220 Transpose transa; 221 Transpose transb; 222 uint64 m; 223 uint64 n; 224 uint64 k; 225 int64 lda; 226 int64 ldb; 227 int64 ldc; 228 int batch_count = 1; 229 int64 stride_a = 0; 230 int64 stride_b = 0; 231 int64 stride_c = 0; 232 }; 233 234 // BLAS support interface -- this can be derived from a GPU executor when the 235 // underlying platform has an BLAS library implementation available. See 236 // StreamExecutor::AsBlas(). 237 // 238 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in 239 // the system. 
Any operation that a user attempts to perform by enqueueing BLAS 240 // operations on a thread not-associated with the CUDA-context has unknown 241 // behavior at the current time; see b/13176597 242 class BlasSupport { 243 public: ~BlasSupport()244 virtual ~BlasSupport() {} 245 246 // Computes the sum of magnitudes of the vector elements. 247 // result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)|+ ... + |Re x(n)| 248 // + |Im x(n)|. 249 // Note that Im x(i) = 0 for real types float/double. 250 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 251 const DeviceMemory<float> &x, int incx, 252 DeviceMemory<float> *result) = 0; 253 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 254 const DeviceMemory<double> &x, int incx, 255 DeviceMemory<double> *result) = 0; 256 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 257 const DeviceMemory<std::complex<float>> &x, int incx, 258 DeviceMemory<float> *result) = 0; 259 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 260 const DeviceMemory<std::complex<double>> &x, int incx, 261 DeviceMemory<double> *result) = 0; 262 263 // Performs a BLAS y <- ax+y operation. 264 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, 265 const DeviceMemory<float> &x, int incx, 266 DeviceMemory<float> *y, int incy) = 0; 267 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, 268 const DeviceMemory<double> &x, int incx, 269 DeviceMemory<double> *y, int incy) = 0; 270 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, 271 std::complex<float> alpha, 272 const DeviceMemory<std::complex<float>> &x, int incx, 273 DeviceMemory<std::complex<float>> *y, int incy) = 0; 274 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, 275 std::complex<double> alpha, 276 const DeviceMemory<std::complex<double>> &x, int incx, 277 DeviceMemory<std::complex<double>> *y, int incy) = 0; 278 279 // Copies vector to another vector: y <- x. 
280 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 281 const DeviceMemory<float> &x, int incx, 282 DeviceMemory<float> *y, int incy) = 0; 283 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 284 const DeviceMemory<double> &x, int incx, 285 DeviceMemory<double> *y, int incy) = 0; 286 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 287 const DeviceMemory<std::complex<float>> &x, int incx, 288 DeviceMemory<std::complex<float>> *y, int incy) = 0; 289 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 290 const DeviceMemory<std::complex<double>> &x, int incx, 291 DeviceMemory<std::complex<double>> *y, int incy) = 0; 292 293 // Performs a BLAS dot product result <- x . y. 294 virtual bool DoBlasDot(Stream *stream, uint64 elem_count, 295 const DeviceMemory<float> &x, int incx, 296 const DeviceMemory<float> &y, int incy, 297 DeviceMemory<float> *result) = 0; 298 virtual bool DoBlasDot(Stream *stream, uint64 elem_count, 299 const DeviceMemory<double> &x, int incx, 300 const DeviceMemory<double> &y, int incy, 301 DeviceMemory<double> *result) = 0; 302 303 // Performs a BLAS dot product result <- conj(x) . y for complex types. 304 virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, 305 const DeviceMemory<std::complex<float>> &x, int incx, 306 const DeviceMemory<std::complex<float>> &y, int incy, 307 DeviceMemory<std::complex<float>> *result) = 0; 308 virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, 309 const DeviceMemory<std::complex<double>> &x, int incx, 310 const DeviceMemory<std::complex<double>> &y, int incy, 311 DeviceMemory<std::complex<double>> *result) = 0; 312 313 // Performs a BLAS dot product result <- x . y for complex types. Note that 314 // x is unconjugated in this routine. 
315 virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, 316 const DeviceMemory<std::complex<float>> &x, int incx, 317 const DeviceMemory<std::complex<float>> &y, int incy, 318 DeviceMemory<std::complex<float>> *result) = 0; 319 virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, 320 const DeviceMemory<std::complex<double>> &x, int incx, 321 const DeviceMemory<std::complex<double>> &y, int incy, 322 DeviceMemory<std::complex<double>> *result) = 0; 323 324 // Computes the Euclidean norm of a vector: result <- ||x||. 325 // See the following link for more information of Euclidean norm: 326 // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm 327 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 328 const DeviceMemory<float> &x, int incx, 329 DeviceMemory<float> *result) = 0; 330 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 331 const DeviceMemory<double> &x, int incx, 332 DeviceMemory<double> *result) = 0; 333 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 334 const DeviceMemory<std::complex<float>> &x, int incx, 335 DeviceMemory<float> *result) = 0; 336 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 337 const DeviceMemory<std::complex<double>> &x, int incx, 338 DeviceMemory<double> *result) = 0; 339 340 // Performs rotation of points in the plane: 341 // x(i) = c*x(i) + s*y(i) 342 // y(i) = c*y(i) - s*x(i). 
343 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 344 DeviceMemory<float> *x, int incx, 345 DeviceMemory<float> *y, int incy, float c, 346 float s) = 0; 347 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 348 DeviceMemory<double> *x, int incx, 349 DeviceMemory<double> *y, int incy, double c, 350 double s) = 0; 351 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 352 DeviceMemory<std::complex<float>> *x, int incx, 353 DeviceMemory<std::complex<float>> *y, int incy, 354 float c, float s) = 0; 355 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 356 DeviceMemory<std::complex<double>> *x, int incx, 357 DeviceMemory<std::complex<double>> *y, int incy, 358 double c, double s) = 0; 359 360 // Computes the parameters for a Givens rotation. 361 // Given the Cartesian coordinates (a, b) of a point, these routines return 362 // the parameters c, s, r, and z associated with the Givens rotation. The 363 // parameters c and s define a unitary matrix such that: 364 // 365 // | c s |.| a | = | r | 366 // | -s c | | b | | 0 | 367 // 368 // The parameter z is defined such that if |a| > |b|, z is s; otherwise if 369 // c is not 0 z is 1/c; otherwise z is 1. 370 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, 371 DeviceMemory<float> *b, DeviceMemory<float> *c, 372 DeviceMemory<float> *s) = 0; 373 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, 374 DeviceMemory<double> *b, DeviceMemory<double> *c, 375 DeviceMemory<double> *s) = 0; 376 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, 377 DeviceMemory<std::complex<float>> *b, 378 DeviceMemory<float> *c, 379 DeviceMemory<std::complex<float>> *s) = 0; 380 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, 381 DeviceMemory<std::complex<double>> *b, 382 DeviceMemory<double> *c, 383 DeviceMemory<std::complex<double>> *s) = 0; 384 385 // Performs modified Givens rotation of points in the plane. 
386 // Given two vectors x and y, each vector element of these vectors is replaced 387 // as follows: 388 // 389 // | x(i) | = H | x(i) | 390 // | y(i) | | y(i) | 391 // 392 // for i=1 to n, where H is a modified Givens transformation matrix whose 393 // values are stored in the param[1] through param[4] array. 394 // For more information please Google this routine. 395 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 396 DeviceMemory<float> *x, int incx, 397 DeviceMemory<float> *y, int incy, 398 const DeviceMemory<float> ¶m) = 0; 399 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 400 DeviceMemory<double> *x, int incx, 401 DeviceMemory<double> *y, int incy, 402 const DeviceMemory<double> ¶m) = 0; 403 404 // Computes the parameters for a modified Givens rotation. 405 // Given Cartesian coordinates (x1, y1) of an input vector, these routines 406 // compute the components of a modified Givens transformation matrix H that 407 // zeros the y-component of the resulting vector: 408 // 409 // | x1 | = H | x1 * sqrt(d1) | 410 // | 0 | | y1 * sqrt(d1) | 411 // 412 // For more information please Google this routine. 413 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, 414 DeviceMemory<float> *d2, DeviceMemory<float> *x1, 415 const DeviceMemory<float> &y1, 416 DeviceMemory<float> *param) = 0; 417 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, 418 DeviceMemory<double> *d2, DeviceMemory<double> *x1, 419 const DeviceMemory<double> &y1, 420 DeviceMemory<double> *param) = 0; 421 422 // Computes the product of a vector by a scalar: x <- a*x. 
423 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, 424 DeviceMemory<float> *x, int incx) = 0; 425 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, 426 DeviceMemory<double> *x, int incx) = 0; 427 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, 428 DeviceMemory<std::complex<float>> *x, int incx) = 0; 429 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, 430 DeviceMemory<std::complex<double>> *x, int incx) = 0; 431 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, 432 std::complex<float> alpha, 433 DeviceMemory<std::complex<float>> *x, int incx) = 0; 434 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, 435 std::complex<double> alpha, 436 DeviceMemory<std::complex<double>> *x, int incx) = 0; 437 438 // Swaps a vector with another vector. 439 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 440 DeviceMemory<float> *x, int incx, 441 DeviceMemory<float> *y, int incy) = 0; 442 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 443 DeviceMemory<double> *x, int incx, 444 DeviceMemory<double> *y, int incy) = 0; 445 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 446 DeviceMemory<std::complex<float>> *x, int incx, 447 DeviceMemory<std::complex<float>> *y, int incy) = 0; 448 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 449 DeviceMemory<std::complex<double>> *x, int incx, 450 DeviceMemory<std::complex<double>> *y, int incy) = 0; 451 452 // Finds the index of the element with maximum absolute value. 
453 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 454 const DeviceMemory<float> &x, int incx, 455 DeviceMemory<int> *result) = 0; 456 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 457 const DeviceMemory<double> &x, int incx, 458 DeviceMemory<int> *result) = 0; 459 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 460 const DeviceMemory<std::complex<float>> &x, int incx, 461 DeviceMemory<int> *result) = 0; 462 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 463 const DeviceMemory<std::complex<double>> &x, 464 int incx, DeviceMemory<int> *result) = 0; 465 466 // Finds the index of the element with minimum absolute value. 467 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 468 const DeviceMemory<float> &x, int incx, 469 DeviceMemory<int> *result) = 0; 470 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 471 const DeviceMemory<double> &x, int incx, 472 DeviceMemory<int> *result) = 0; 473 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 474 const DeviceMemory<std::complex<float>> &x, int incx, 475 DeviceMemory<int> *result) = 0; 476 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 477 const DeviceMemory<std::complex<double>> &x, 478 int incx, DeviceMemory<int> *result) = 0; 479 480 // Computes a matrix-vector product using a general band matrix: 481 // 482 // y <- alpha * a * x + beta * y, 483 // or 484 // y <- alpha * a' * x + beta * y, 485 // or 486 // y <- alpha * conj(a') * x + beta * y, 487 // 488 // alpha and beta are scalars; a is an m-by-n general band matrix, with kl 489 // sub-diagonals and ku super-diagonals; x is a vector with 490 // n(trans==kNoTranspose)/m(otherwise) elements; 491 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
492 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 493 uint64 n, uint64 kl, uint64 ku, float alpha, 494 const DeviceMemory<float> &a, int lda, 495 const DeviceMemory<float> &x, int incx, float beta, 496 DeviceMemory<float> *y, int incy) = 0; 497 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 498 uint64 n, uint64 kl, uint64 ku, double alpha, 499 const DeviceMemory<double> &a, int lda, 500 const DeviceMemory<double> &x, int incx, double beta, 501 DeviceMemory<double> *y, int incy) = 0; 502 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 503 uint64 n, uint64 kl, uint64 ku, 504 std::complex<float> alpha, 505 const DeviceMemory<std::complex<float>> &a, int lda, 506 const DeviceMemory<std::complex<float>> &x, int incx, 507 std::complex<float> beta, 508 DeviceMemory<std::complex<float>> *y, int incy) = 0; 509 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 510 uint64 n, uint64 kl, uint64 ku, 511 std::complex<double> alpha, 512 const DeviceMemory<std::complex<double>> &a, int lda, 513 const DeviceMemory<std::complex<double>> &x, int incx, 514 std::complex<double> beta, 515 DeviceMemory<std::complex<double>> *y, int incy) = 0; 516 517 // Computes a matrix-vector product using a general matrix. 518 // 519 // y <- alpha * a * x + beta * y, 520 // or 521 // y <- alpha * a' * x + beta * y, 522 // or 523 // y <- alpha * conj(a') * x + beta * y, 524 // 525 // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector 526 // with n(trans==kNoTranspose)/m(otherwise) elements; 527 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
528 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 529 uint64 n, float alpha, const DeviceMemory<float> &a, 530 int lda, const DeviceMemory<float> &x, int incx, 531 float beta, DeviceMemory<float> *y, int incy) = 0; 532 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 533 uint64 n, double alpha, const DeviceMemory<double> &a, 534 int lda, const DeviceMemory<double> &x, int incx, 535 double beta, DeviceMemory<double> *y, int incy) = 0; 536 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 537 uint64 n, std::complex<float> alpha, 538 const DeviceMemory<std::complex<float>> &a, int lda, 539 const DeviceMemory<std::complex<float>> &x, int incx, 540 std::complex<float> beta, 541 DeviceMemory<std::complex<float>> *y, int incy) = 0; 542 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 543 uint64 n, std::complex<double> alpha, 544 const DeviceMemory<std::complex<double>> &a, int lda, 545 const DeviceMemory<std::complex<double>> &x, int incx, 546 std::complex<double> beta, 547 DeviceMemory<std::complex<double>> *y, int incy) = 0; 548 549 virtual bool DoBlasGemvWithProfiling( 550 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, 551 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, 552 int incx, float beta, DeviceMemory<float> *y, int incy, 553 ProfileResult *output_profile_result) = 0; 554 virtual bool DoBlasGemvWithProfiling( 555 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, 556 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, 557 int incx, double beta, DeviceMemory<double> *y, int incy, 558 ProfileResult *output_profile_result) = 0; 559 virtual bool DoBlasGemvWithProfiling( 560 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, 561 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, 562 int lda, const DeviceMemory<std::complex<float>> &x, int incx, 563 
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, 564 ProfileResult *output_profile_result) = 0; 565 virtual bool DoBlasGemvWithProfiling( 566 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, 567 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, 568 int lda, const DeviceMemory<std::complex<double>> &x, int incx, 569 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, 570 int incy, ProfileResult *output_profile_result) = 0; 571 572 // Performs a rank-1 update of a general matrix. 573 // 574 // a <- alpha * x * y' + a, 575 // 576 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 577 // an m-by-n general matrix. 578 virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, 579 const DeviceMemory<float> &x, int incx, 580 const DeviceMemory<float> &y, int incy, 581 DeviceMemory<float> *a, int lda) = 0; 582 virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, 583 const DeviceMemory<double> &x, int incx, 584 const DeviceMemory<double> &y, int incy, 585 DeviceMemory<double> *a, int lda) = 0; 586 587 // Performs a rank-1 update (conjugated) of a general matrix. 588 // 589 // a <- alpha * x * conj(y') + a, 590 // 591 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 592 // an m-by-n general matrix. 593 virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, 594 std::complex<float> alpha, 595 const DeviceMemory<std::complex<float>> &x, int incx, 596 const DeviceMemory<std::complex<float>> &y, int incy, 597 DeviceMemory<std::complex<float>> *a, int lda) = 0; 598 virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, 599 std::complex<double> alpha, 600 const DeviceMemory<std::complex<double>> &x, int incx, 601 const DeviceMemory<std::complex<double>> &y, int incy, 602 DeviceMemory<std::complex<double>> *a, int lda) = 0; 603 604 // Performs a rank-1 update (unconjugated) of a general matrix. 
605 // 606 // a <- alpha * x * y' + a, 607 // 608 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 609 // an m-by-n general matrix. 610 virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, 611 std::complex<float> alpha, 612 const DeviceMemory<std::complex<float>> &x, int incx, 613 const DeviceMemory<std::complex<float>> &y, int incy, 614 DeviceMemory<std::complex<float>> *a, int lda) = 0; 615 virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, 616 std::complex<double> alpha, 617 const DeviceMemory<std::complex<double>> &x, int incx, 618 const DeviceMemory<std::complex<double>> &y, int incy, 619 DeviceMemory<std::complex<double>> *a, int lda) = 0; 620 621 // Computes a matrix-vector product using a Hermitian band matrix. 622 // 623 // y <- alpha * a * x + beta * y, 624 // 625 // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k 626 // super-diagonals; x and y are n-element vectors. 627 virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 628 uint64 k, std::complex<float> alpha, 629 const DeviceMemory<std::complex<float>> &a, int lda, 630 const DeviceMemory<std::complex<float>> &x, int incx, 631 std::complex<float> beta, 632 DeviceMemory<std::complex<float>> *y, int incy) = 0; 633 virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 634 uint64 k, std::complex<double> alpha, 635 const DeviceMemory<std::complex<double>> &a, int lda, 636 const DeviceMemory<std::complex<double>> &x, int incx, 637 std::complex<double> beta, 638 DeviceMemory<std::complex<double>> *y, int incy) = 0; 639 640 // Computes a matrix-vector product using a Hermitian matrix. 641 // 642 // y <- alpha * a * x + beta * y, 643 // 644 // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are 645 // n-element vectors. 
virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *y, int incy) = 0;
virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *y, int incy) = 0;

// Performs a rank-1 update of a Hermitian matrix.
//
//     a <- alpha * x * conj(x') + a,
//
// alpha is a real scalar (float/double); x is an n-element vector; a is an
// n-by-n Hermitian matrix.
virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *a, int lda) = 0;
virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *a, int lda) = 0;

// Performs a rank-2 update of a Hermitian matrix.
//
//     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
// matrix.
680 virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 681 std::complex<float> alpha, 682 const DeviceMemory<std::complex<float>> &x, int incx, 683 const DeviceMemory<std::complex<float>> &y, int incy, 684 DeviceMemory<std::complex<float>> *a, int lda) = 0; 685 virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 686 std::complex<double> alpha, 687 const DeviceMemory<std::complex<double>> &x, int incx, 688 const DeviceMemory<std::complex<double>> &y, int incy, 689 DeviceMemory<std::complex<double>> *a, int lda) = 0; 690 691 // Computes a matrix-vector product using a Hermitian packed matrix. 692 // 693 // y <- alpha * a * x + beta * y, 694 // 695 // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in 696 // packed form; x and y are n-element vectors. 697 virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 698 std::complex<float> alpha, 699 const DeviceMemory<std::complex<float>> &ap, 700 const DeviceMemory<std::complex<float>> &x, int incx, 701 std::complex<float> beta, 702 DeviceMemory<std::complex<float>> *y, int incy) = 0; 703 virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 704 std::complex<double> alpha, 705 const DeviceMemory<std::complex<double>> &ap, 706 const DeviceMemory<std::complex<double>> &x, int incx, 707 std::complex<double> beta, 708 DeviceMemory<std::complex<double>> *y, int incy) = 0; 709 710 // Performs a rank-1 update of a Hermitian packed matrix. 711 // 712 // a <- alpha * x * conj(x') + a, 713 // 714 // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian 715 // matrix, supplied in packed form. 
virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *ap) = 0;
virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *ap) = 0;

// Performs a rank-2 update of a Hermitian packed matrix.
//
//     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
// matrix, supplied in packed form (the ap parameter).
virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        const DeviceMemory<std::complex<float>> &y, int incy,
                        DeviceMemory<std::complex<float>> *ap) = 0;
virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        const DeviceMemory<std::complex<double>> &y, int incy,
                        DeviceMemory<std::complex<double>> *ap) = 0;

// Computes a matrix-vector product using a symmetric band matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
// super-diagonals; x and y are n-element vectors.
virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &x, int incx,
                        float beta, DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &x, int incx,
                        double beta, DeviceMemory<double> *y, int incy) = 0;

// Computes a matrix-vector product using a symmetric packed matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
// packed form in ap (uplo selects whether the upper or lower triangle is the
// one stored); x and y are n-element vectors.
virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &ap,
                        const DeviceMemory<float> &x, int incx, float beta,
                        DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &ap,
                        const DeviceMemory<double> &x, int incx, double beta,
                        DeviceMemory<double> *y, int incy) = 0;

// Performs a rank-1 update of a symmetric packed matrix.
//
//     a <- alpha * x * x' + a,
//
// alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
// matrix, supplied in packed form in ap (updated in place).
virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha, const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *ap) = 0;
virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha, const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *ap) = 0;

// Performs a rank-2 update of a symmetric packed matrix.
//
//     a <- alpha * x * y' + alpha * y * x' + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
// matrix, supplied in packed form in ap (updated in place).
virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &x, int incx,
                        const DeviceMemory<float> &y, int incy,
                        DeviceMemory<float> *ap) = 0;
virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &x, int incx,
                        const DeviceMemory<double> &y, int incy,
                        DeviceMemory<double> *ap) = 0;

// Computes a matrix-vector product for a symmetric matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric matrix (only the
// triangle selected by uplo is used); x and y are n-element vectors.
virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &x, int incx, float beta,
                        DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &x, int incx, double beta,
                        DeviceMemory<double> *y, int incy) = 0;

// Performs a rank-1 update of a symmetric matrix.
//
//     a <- alpha * x * x' + a,
//
// alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
// matrix, updated in place.
virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha, const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *a, int lda) = 0;
virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha, const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *a, int lda) = 0;

// Performs a rank-2 update of a symmetric matrix.
//
//     a <- alpha * x * y' + alpha * y * x' + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
// matrix, updated in place.
virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &x, int incx,
                        const DeviceMemory<float> &y, int incy,
                        DeviceMemory<float> *a, int lda) = 0;
virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &x, int incx,
                        const DeviceMemory<double> &y, int incy,
                        DeviceMemory<double> *a, int lda) = 0;

// Computes a matrix-vector product using a triangular band matrix.
//
//     x <- a * x,
//   or
//     x <- a' * x,
//   or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
// with k+1 diagonals; x is an n-element vector, overwritten with the result.
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<float>> &a,
                        int lda, DeviceMemory<std::complex<float>> *x,
                        int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<double>> &a,
                        int lda, DeviceMemory<std::complex<double>> *x,
                        int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// band matrix as below:
//
//     a * x = b,
//   or
//     a' * x = b,
//   or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular band matrix, with k+1 diagonals. There is no separate b
// argument: x supplies b on entry and receives the solution.
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<float>> &a,
                        int lda, DeviceMemory<std::complex<float>> *x,
                        int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<double>> &a,
                        int lda, DeviceMemory<std::complex<double>> *x,
                        int incx) = 0;

// Computes a matrix-vector product using a triangular packed matrix.
//
//     x <- a * x,
//   or
//     x <- a' * x,
//   or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
// supplied in packed form; x is an n-element vector, overwritten with the
// result.
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                        int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &ap,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &ap,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &ap,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// packed matrix as below:
//
//     a * x = b,
//   or
//     a' * x = b,
//   or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular matrix, supplied in packed form. There is no separate b
// argument: x supplies b on entry and receives the solution.
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                        int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &ap,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &ap,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &ap,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Computes a matrix-vector product using a triangular matrix.
//
//     x <- a * x,
//   or
//     x <- a' * x,
//   or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is
// an n-element vector, overwritten with the result.
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// matrix as below:
//
//     a * x = b,
//   or
//     a' * x = b,
//   or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular matrix. There is no separate b argument: x supplies b on
// entry and receives the solution.
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Computes a matrix-matrix product with general matrices:
//
//     c <- alpha * op(a) * op(b) + beta * c,
//
// op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
// beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
// op(b) is a k-by-n matrix; c is an m-by-n matrix.
//
// Note: The Eigen::half overload takes float alpha and beta and uses float
// precision internally; a version that uses half precision internally is not
// offered by this entry point. NOTE(review): DoBlasGemmWithAlgorithm takes a
// ComputationType and may cover that case — confirm. Batched half-precision
// GEMM is available through the Eigen::half overloads of DoBlasGemmBatched and
// DoBlasGemmStridedBatched.
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<Eigen::half> &a,
                        int lda, const DeviceMemory<Eigen::half> &b, int ldb,
                        float beta, DeviceMemory<Eigen::half> *c,
                        int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &b, int ldb, float beta,
                        DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &b, int ldb, double beta,
                        DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Same operation and overload set as DoBlasGemm, with an extra
// output_profile_result argument. NOTE(review): presumably the implementation
// records timing/validity for the call into *output_profile_result — confirm
// against the backend implementation.
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    ProfileResult *output_profile_result) = 0;

// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm, appending
// them to *out_algorithms.
virtual bool GetBlasGemmAlgorithms(
    std::vector<AlgorithmType> *out_algorithms) = 0;

// Like DoBlasGemm, but accepts an algorithm and a compute type.
//
// The compute type lets you say (e.g.) that the inputs and outputs are
// Eigen::halfs, but you want the internal computations to be done with
// float32 precision.
//
// Note the subtle difference in the version that accepts Eigen::half --
// alpha and beta have type const HostOrDeviceScalar<Eigen::half>&, i.e. they
// are half-precision values, whereas the Eigen::half overload of DoBlasGemm
// takes float alpha and beta. The int8 overload takes int alpha/beta and
// writes int32 results to c.
//
// If output_profile_result is not null, a failure here does not put the
// stream in a failure state. Instead, success/failure is indicated by
// output_profile_result->is_valid(). This lets you use this function for
// choosing the best algorithm among many (some of which may fail) without
// creating a new Stream for each attempt.
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
    const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
    int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
    const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb,
    const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
    int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
    int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<float>> &beta,
    DeviceMemory<std::complex<float>> *c, int ldc,
    ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<double>> &beta,
    DeviceMemory<std::complex<double>> *c, int ldc,
    ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;

// Computes a batch of matrix-matrix products with general matrices.
// This is a batched version of DoBlasGemm.
// The batched GEMM computes the matrix product for each input/output in a, b,
// and c, which contain batch_count DeviceMemory objects. c is a const slice
// of mutable DeviceMemory pointers; results are written through them.
// NOTE(review): scratch_allocator is presumably used for temporary device
// allocations (e.g. per-batch pointer arrays) — confirm against the backend.
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha,
    const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
    const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
    std::complex<float> beta,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
    std::complex<double> beta,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;

// Batched gemm with strides instead of pointer arrays: each of a, b and c is
// a single buffer holding batch_count matrices, with consecutive matrices
// stride_a / stride_b / stride_c apart. NOTE(review): strides are presumably
// measured in elements rather than bytes — confirm against the backend.
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
    int64 stride_c, int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
    int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
    int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    int64 stride_c, int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    int64 stride_c, int batch_count) = 0;

// Computes a matrix-matrix product where one input matrix is Hermitian:
//
//     c <- alpha * a * b + beta * c,     (side == kLeft)
//   or
//     c <- alpha * b * a + beta * c,     (side == kRight)
//
// alpha and beta are scalars; a is a Hermitian matrix (m-by-m on the left,
// n-by-n on the right); b and c are m-by-n matrices.
virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a Hermitian rank-k update.
//
//     c <- alpha * a * conj(a') + beta * c,
//   or
//     c <- alpha * conj(a') * a + beta * c,
//
// alpha and beta are real scalars (passed as float/double); c is an n-by-n
// Hermitian matrix; a is an n-by-k matrix in the first case and a k-by-n
// matrix in the second case.
virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        float alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc) = 0;
virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        double alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc) = 0;

// Performs a Hermitian rank-2k update.
//
//     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
//   or
//     c <- alpha * conj(a') * b + conj(alpha) * conj(b') * a + beta * c,
//
// alpha is a complex scalar and beta a real scalar (passed as float/double);
// c is an n-by-n Hermitian matrix; a and b are n-by-k matrices in the first
// case and k-by-n matrices in the second case.
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<float> alpha,
                         const DeviceMemory<std::complex<float>> &a, int lda,
                         const DeviceMemory<std::complex<float>> &b, int ldb,
                         float beta, DeviceMemory<std::complex<float>> *c,
                         int ldc) = 0;
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<double> alpha,
                         const DeviceMemory<std::complex<double>> &a, int lda,
                         const DeviceMemory<std::complex<double>> &b, int ldb,
                         double beta, DeviceMemory<std::complex<double>> *c,
                         int ldc) = 0;

// Computes a matrix-matrix product where one input matrix is symmetric.
//
//     c <- alpha * a * b + beta * c,     (side == kLeft)
//   or
//     c <- alpha * b * a + beta * c,     (side == kRight)
//
// alpha and beta are scalars; a is a symmetric matrix (m-by-m on the left,
// n-by-n on the right); b and c are m-by-n matrices.
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &b, int ldb, float beta,
                        DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &b, int ldb, double beta,
                        DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a symmetric rank-k update.
//
//     c <- alpha * a * a' + beta * c,
//   or
//     c <- alpha * a' * a + beta * c,
//
// alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
// matrix in the first case and a k-by-n matrix in the second case.
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        float beta, DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        double beta, DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a symmetric rank-2k update.
//
//     c <- alpha * a * b' + alpha * b * a' + beta * c,
//   or
//     c <- alpha * b' * a + alpha * a' * b + beta * c,
//
// alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
// n-by-k matrices in the first case and k-by-n matrices in the second case.
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         float alpha, const DeviceMemory<float> &a, int lda,
                         const DeviceMemory<float> &b, int ldb, float beta,
                         DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         double alpha, const DeviceMemory<double> &a, int lda,
                         const DeviceMemory<double> &b, int ldb, double beta,
                         DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<float> alpha,
                         const DeviceMemory<std::complex<float>> &a, int lda,
                         const DeviceMemory<std::complex<float>> &b, int ldb,
                         std::complex<float> beta,
                         DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<double> alpha,
                         const DeviceMemory<std::complex<double>> &a, int lda,
                         const DeviceMemory<std::complex<double>> &b, int ldb,
                         std::complex<double> beta,
                         DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Computes a matrix-matrix product where one input matrix is triangular.
//
//     b <- alpha * op(a) * b,     (side == kLeft)
//   or
//     b <- alpha * b * op(a)      (side == kRight)
//
// alpha is a scalar; b is an m-by-n matrix, overwritten with the result; a is
// a unit, or non-unit, upper or lower triangular matrix; op(a) is one of
// op(a) = a, or op(a) = a', or op(a) = conj(a').
1399 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1400 blas::UpperLower uplo, blas::Transpose transa, 1401 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 1402 const DeviceMemory<float> &a, int lda, 1403 DeviceMemory<float> *b, int ldb) = 0; 1404 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1405 blas::UpperLower uplo, blas::Transpose transa, 1406 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 1407 const DeviceMemory<double> &a, int lda, 1408 DeviceMemory<double> *b, int ldb) = 0; 1409 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1410 blas::UpperLower uplo, blas::Transpose transa, 1411 blas::Diagonal diag, uint64 m, uint64 n, 1412 std::complex<float> alpha, 1413 const DeviceMemory<std::complex<float>> &a, int lda, 1414 DeviceMemory<std::complex<float>> *b, int ldb) = 0; 1415 virtual bool DoBlasTrmm(Stream *stream, blas::Side side, 1416 blas::UpperLower uplo, blas::Transpose transa, 1417 blas::Diagonal diag, uint64 m, uint64 n, 1418 std::complex<double> alpha, 1419 const DeviceMemory<std::complex<double>> &a, int lda, 1420 DeviceMemory<std::complex<double>> *b, int ldb) = 0; 1421 1422 // Solves a triangular matrix equation. 1423 // 1424 // op(a) * x = alpha * b, 1425 // or 1426 // x * op(a) = alpha * b 1427 // 1428 // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit, 1429 // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', 1430 // or op(a) = conj(a'). 
1431 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1432 blas::UpperLower uplo, blas::Transpose transa, 1433 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 1434 const DeviceMemory<float> &a, int lda, 1435 DeviceMemory<float> *b, int ldb) = 0; 1436 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1437 blas::UpperLower uplo, blas::Transpose transa, 1438 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 1439 const DeviceMemory<double> &a, int lda, 1440 DeviceMemory<double> *b, int ldb) = 0; 1441 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1442 blas::UpperLower uplo, blas::Transpose transa, 1443 blas::Diagonal diag, uint64 m, uint64 n, 1444 std::complex<float> alpha, 1445 const DeviceMemory<std::complex<float>> &a, int lda, 1446 DeviceMemory<std::complex<float>> *b, int ldb) = 0; 1447 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1448 blas::UpperLower uplo, blas::Transpose transa, 1449 blas::Diagonal diag, uint64 m, uint64 n, 1450 std::complex<double> alpha, 1451 const DeviceMemory<std::complex<double>> &a, int lda, 1452 DeviceMemory<std::complex<double>> *b, int ldb) = 0; 1453 1454 // Creates a backend-specific plan object for a blaslt matmul operation, which 1455 // can then be passed to DoBlasLtMatmul(). When possible, plans should be 1456 // created once and reused for multiple calls to DoBlasLtMatmul(). 1457 virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> 1458 CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) = 0; 1459 1460 // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are 1461 // returned in the order of increasing estimated compute time according to an 1462 // internal heuristic. The first returned algorithm can be used as the default 1463 // algorithm if no autotuning is to be performed. 
1464 virtual port::StatusOr< 1465 std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> 1466 GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, 1467 size_t max_workspace_size, 1468 int max_algorithm_count) = 0; 1469 1470 // Executes a blaslt matmul operation on the stream. If output_profile_result 1471 // is not nullptr, the operation is profiled, error messages are 1472 // suppressed, and output_profile_result->algorithm() is set to 1473 // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when 1474 // creating the plan, the bias argument here must refer to a valid device 1475 // vector of length equal to the number of rows in matrix c. If epilogue was 1476 // set to any other value then the bias argument here must be null. The bias 1477 // vector is broadcast across the batch dimension. 1478 // Note that the data types of a and b (c and bias) must match the ab_type 1479 // (c_type) with which the plan was created, and the data types of alpha and 1480 // beta must match the data type of c. 
1481 virtual bool DoBlasLtMatmul( 1482 Stream *stream, const blas::IBlasLtMatmulPlan *plan, 1483 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, 1484 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, 1485 DeviceMemoryBase c, ScratchAllocator *scratch_allocator, 1486 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, 1487 blas::ProfileResult *output_profile_result) = 0; 1488 1489 template <typename ABType, typename CType> 1490 bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan, 1491 const HostOrDeviceScalar<CType> &alpha, 1492 const DeviceMemory<ABType> &a, 1493 const DeviceMemory<ABType> &b, 1494 const HostOrDeviceScalar<CType> &beta, 1495 DeviceMemory<CType> *c, 1496 ScratchAllocator *scratch_allocator, 1497 const blas::IBlasLtMatmulAlgorithm *algorithm, 1498 const DeviceMemory<CType> &bias = {}, 1499 blas::ProfileResult *output_profile_result = nullptr) { 1500 constexpr blas::DataType ab_type = blas::ToDataType<ABType>::value; 1501 if (ab_type != plan->ab_type()) { 1502 VLOG(2) << "DoBlasLtMatmul returning false because a and b type does " 1503 "not match plan: expected " 1504 << plan->ab_type() << ", got " << ab_type; 1505 return false; 1506 } 1507 constexpr blas::DataType c_type = blas::ToDataType<CType>::value; 1508 if (c_type != plan->c_type()) { 1509 VLOG(2) << "DoBlasLtMatmul returning false because c type does " 1510 "not match plan: expected " 1511 << plan->c_type() << ", got " << c_type; 1512 return false; 1513 } 1514 return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c, 1515 scratch_allocator, algorithm, bias, 1516 output_profile_result); 1517 } 1518 1519 virtual port::Status GetVersion(std::string *version) = 0; 1520 1521 protected: BlasSupport()1522 BlasSupport() {} 1523 1524 private: 1525 SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport); 1526 }; 1527 1528 // Macro used to quickly declare overrides for abstract virtuals in the 1529 // BlasSupport base class. 
1530 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ 1531 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1532 const DeviceMemory<float> &x, int incx, \ 1533 DeviceMemory<float> *result) override; \ 1534 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1535 const DeviceMemory<double> &x, int incx, \ 1536 DeviceMemory<double> *result) override; \ 1537 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1538 const DeviceMemory<std::complex<float>> &x, int incx, \ 1539 DeviceMemory<float> *result) override; \ 1540 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1541 const DeviceMemory<std::complex<double>> &x, int incx, \ 1542 DeviceMemory<double> *result) override; \ 1543 bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, \ 1544 const DeviceMemory<float> &x, int incx, \ 1545 DeviceMemory<float> *y, int incy) override; \ 1546 bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, \ 1547 const DeviceMemory<double> &x, int incx, \ 1548 DeviceMemory<double> *y, int incy) override; \ 1549 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1550 std::complex<float> alpha, \ 1551 const DeviceMemory<std::complex<float>> &x, int incx, \ 1552 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1553 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1554 std::complex<double> alpha, \ 1555 const DeviceMemory<std::complex<double>> &x, int incx, \ 1556 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1557 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1558 const DeviceMemory<float> &x, int incx, \ 1559 DeviceMemory<float> *y, int incy) override; \ 1560 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1561 const DeviceMemory<double> &x, int incx, \ 1562 DeviceMemory<double> *y, int incy) override; \ 1563 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1564 const DeviceMemory<std::complex<float>> &x, int incx, \ 1565 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1566 bool 
DoBlasCopy(Stream *stream, uint64 elem_count, \ 1567 const DeviceMemory<std::complex<double>> &x, int incx, \ 1568 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1569 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1570 const DeviceMemory<float> &x, int incx, \ 1571 const DeviceMemory<float> &y, int incy, \ 1572 DeviceMemory<float> *result) override; \ 1573 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1574 const DeviceMemory<double> &x, int incx, \ 1575 const DeviceMemory<double> &y, int incy, \ 1576 DeviceMemory<double> *result) override; \ 1577 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1578 const DeviceMemory<std::complex<float>> &x, int incx, \ 1579 const DeviceMemory<std::complex<float>> &y, int incy, \ 1580 DeviceMemory<std::complex<float>> *result) override; \ 1581 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1582 const DeviceMemory<std::complex<double>> &x, int incx, \ 1583 const DeviceMemory<std::complex<double>> &y, int incy, \ 1584 DeviceMemory<std::complex<double>> *result) override; \ 1585 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1586 const DeviceMemory<std::complex<float>> &x, int incx, \ 1587 const DeviceMemory<std::complex<float>> &y, int incy, \ 1588 DeviceMemory<std::complex<float>> *result) override; \ 1589 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1590 const DeviceMemory<std::complex<double>> &x, int incx, \ 1591 const DeviceMemory<std::complex<double>> &y, int incy, \ 1592 DeviceMemory<std::complex<double>> *result) override; \ 1593 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1594 const DeviceMemory<float> &x, int incx, \ 1595 DeviceMemory<float> *result) override; \ 1596 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1597 const DeviceMemory<double> &x, int incx, \ 1598 DeviceMemory<double> *result) override; \ 1599 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1600 const DeviceMemory<std::complex<float>> &x, int incx, \ 1601 DeviceMemory<float> *result) 
override; \ 1602 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1603 const DeviceMemory<std::complex<double>> &x, int incx, \ 1604 DeviceMemory<double> *result) override; \ 1605 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1606 int incx, DeviceMemory<float> *y, int incy, float c, float s) \ 1607 override; \ 1608 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1609 int incx, DeviceMemory<double> *y, int incy, double c, \ 1610 double s) override; \ 1611 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1612 DeviceMemory<std::complex<float>> *x, int incx, \ 1613 DeviceMemory<std::complex<float>> *y, int incy, float c, \ 1614 float s) override; \ 1615 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1616 DeviceMemory<std::complex<double>> *x, int incx, \ 1617 DeviceMemory<std::complex<double>> *y, int incy, double c, \ 1618 double s) override; \ 1619 bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, \ 1620 DeviceMemory<float> *b, DeviceMemory<float> *c, \ 1621 DeviceMemory<float> *s) override; \ 1622 bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, \ 1623 DeviceMemory<double> *b, DeviceMemory<double> *c, \ 1624 DeviceMemory<double> *s) override; \ 1625 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, \ 1626 DeviceMemory<std::complex<float>> *b, \ 1627 DeviceMemory<float> *c, \ 1628 DeviceMemory<std::complex<float>> *s) override; \ 1629 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, \ 1630 DeviceMemory<std::complex<double>> *b, \ 1631 DeviceMemory<double> *c, \ 1632 DeviceMemory<std::complex<double>> *s) override; \ 1633 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1634 int incx, DeviceMemory<float> *y, int incy, \ 1635 const DeviceMemory<float> ¶m) override; \ 1636 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1637 int incx, DeviceMemory<double> *y, int incy, \ 1638 const 
DeviceMemory<double> ¶m) override; \ 1639 bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, \ 1640 DeviceMemory<float> *d2, DeviceMemory<float> *x1, \ 1641 const DeviceMemory<float> &y1, DeviceMemory<float> *param) \ 1642 override; \ 1643 bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, \ 1644 DeviceMemory<double> *d2, DeviceMemory<double> *x1, \ 1645 const DeviceMemory<double> &y1, \ 1646 DeviceMemory<double> *param) override; \ 1647 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1648 DeviceMemory<float> *x, int incx) override; \ 1649 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1650 DeviceMemory<double> *x, int incx) override; \ 1651 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1652 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1653 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1654 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1655 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1656 std::complex<float> alpha, \ 1657 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1658 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1659 std::complex<double> alpha, \ 1660 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1661 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1662 int incx, DeviceMemory<float> *y, int incy) override; \ 1663 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1664 int incx, DeviceMemory<double> *y, int incy) override; \ 1665 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1666 DeviceMemory<std::complex<float>> *x, int incx, \ 1667 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1668 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1669 DeviceMemory<std::complex<double>> *x, int incx, \ 1670 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1671 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1672 
const DeviceMemory<float> &x, int incx, \ 1673 DeviceMemory<int> *result) override; \ 1674 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1675 const DeviceMemory<double> &x, int incx, \ 1676 DeviceMemory<int> *result) override; \ 1677 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1678 const DeviceMemory<std::complex<float>> &x, int incx, \ 1679 DeviceMemory<int> *result) override; \ 1680 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1681 const DeviceMemory<std::complex<double>> &x, int incx, \ 1682 DeviceMemory<int> *result) override; \ 1683 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1684 const DeviceMemory<float> &x, int incx, \ 1685 DeviceMemory<int> *result) override; \ 1686 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1687 const DeviceMemory<double> &x, int incx, \ 1688 DeviceMemory<int> *result) override; \ 1689 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1690 const DeviceMemory<std::complex<float>> &x, int incx, \ 1691 DeviceMemory<int> *result) override; \ 1692 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1693 const DeviceMemory<std::complex<double>> &x, int incx, \ 1694 DeviceMemory<int> *result) override; \ 1695 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1696 uint64 kl, uint64 ku, float alpha, \ 1697 const DeviceMemory<float> &a, int lda, \ 1698 const DeviceMemory<float> &x, int incx, float beta, \ 1699 DeviceMemory<float> *y, int incy) override; \ 1700 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1701 uint64 kl, uint64 ku, double alpha, \ 1702 const DeviceMemory<double> &a, int lda, \ 1703 const DeviceMemory<double> &x, int incx, double beta, \ 1704 DeviceMemory<double> *y, int incy) override; \ 1705 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1706 uint64 kl, uint64 ku, std::complex<float> alpha, \ 1707 const DeviceMemory<std::complex<float>> &a, int lda, \ 1708 const 
DeviceMemory<std::complex<float>> &x, int incx, \ 1709 std::complex<float> beta, \ 1710 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1711 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1712 uint64 kl, uint64 ku, std::complex<double> alpha, \ 1713 const DeviceMemory<std::complex<double>> &a, int lda, \ 1714 const DeviceMemory<std::complex<double>> &x, int incx, \ 1715 std::complex<double> beta, \ 1716 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1717 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1718 float alpha, const DeviceMemory<float> &a, int lda, \ 1719 const DeviceMemory<float> &x, int incx, float beta, \ 1720 DeviceMemory<float> *y, int incy) override; \ 1721 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1722 double alpha, const DeviceMemory<double> &a, int lda, \ 1723 const DeviceMemory<double> &x, int incx, double beta, \ 1724 DeviceMemory<double> *y, int incy) override; \ 1725 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1726 std::complex<float> alpha, \ 1727 const DeviceMemory<std::complex<float>> &a, int lda, \ 1728 const DeviceMemory<std::complex<float>> &x, int incx, \ 1729 std::complex<float> beta, \ 1730 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1731 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1732 std::complex<double> alpha, \ 1733 const DeviceMemory<std::complex<double>> &a, int lda, \ 1734 const DeviceMemory<std::complex<double>> &x, int incx, \ 1735 std::complex<double> beta, \ 1736 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1737 bool DoBlasGemvWithProfiling( \ 1738 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \ 1739 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \ 1740 int incx, float beta, DeviceMemory<float> *y, int incy, \ 1741 blas::ProfileResult *output_profile_result) 
override; \ 1742 bool DoBlasGemvWithProfiling( \ 1743 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \ 1744 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \ 1745 int incx, double beta, DeviceMemory<double> *y, int incy, \ 1746 blas::ProfileResult *output_profile_result) override; \ 1747 bool DoBlasGemvWithProfiling( \ 1748 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1749 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \ 1750 int lda, const DeviceMemory<std::complex<float>> &x, int incx, \ 1751 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \ 1752 int incy, blas::ProfileResult *output_profile_result) override; \ 1753 bool DoBlasGemvWithProfiling( \ 1754 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1755 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \ 1756 int lda, const DeviceMemory<std::complex<double>> &x, int incx, \ 1757 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \ 1758 int incy, blas::ProfileResult *output_profile_result) override; \ 1759 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \ 1760 const DeviceMemory<float> &x, int incx, \ 1761 const DeviceMemory<float> &y, int incy, \ 1762 DeviceMemory<float> *a, int lda) override; \ 1763 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, \ 1764 const DeviceMemory<double> &x, int incx, \ 1765 const DeviceMemory<double> &y, int incy, \ 1766 DeviceMemory<double> *a, int lda) override; \ 1767 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1768 std::complex<float> alpha, \ 1769 const DeviceMemory<std::complex<float>> &x, int incx, \ 1770 const DeviceMemory<std::complex<float>> &y, int incy, \ 1771 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1772 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1773 std::complex<double> alpha, \ 1774 const DeviceMemory<std::complex<double>> &x, int incx, \ 1775 
const DeviceMemory<std::complex<double>> &y, int incy, \ 1776 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1777 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1778 std::complex<float> alpha, \ 1779 const DeviceMemory<std::complex<float>> &x, int incx, \ 1780 const DeviceMemory<std::complex<float>> &y, int incy, \ 1781 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1782 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1783 std::complex<double> alpha, \ 1784 const DeviceMemory<std::complex<double>> &x, int incx, \ 1785 const DeviceMemory<std::complex<double>> &y, int incy, \ 1786 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1787 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1788 std::complex<float> alpha, \ 1789 const DeviceMemory<std::complex<float>> &a, int lda, \ 1790 const DeviceMemory<std::complex<float>> &x, int incx, \ 1791 std::complex<float> beta, \ 1792 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1793 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1794 std::complex<double> alpha, \ 1795 const DeviceMemory<std::complex<double>> &a, int lda, \ 1796 const DeviceMemory<std::complex<double>> &x, int incx, \ 1797 std::complex<double> beta, \ 1798 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1799 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1800 std::complex<float> alpha, \ 1801 const DeviceMemory<std::complex<float>> &a, int lda, \ 1802 const DeviceMemory<std::complex<float>> &x, int incx, \ 1803 std::complex<float> beta, \ 1804 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1805 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1806 std::complex<double> alpha, \ 1807 const DeviceMemory<std::complex<double>> &a, int lda, \ 1808 const DeviceMemory<std::complex<double>> &x, int incx, \ 1809 std::complex<double> beta, \ 1810 DeviceMemory<std::complex<double>> *y, int 
incy) override; \ 1811 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1812 const DeviceMemory<std::complex<float>> &x, int incx, \ 1813 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1814 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1815 double alpha, const DeviceMemory<std::complex<double>> &x, \ 1816 int incx, DeviceMemory<std::complex<double>> *a, int lda) \ 1817 override; \ 1818 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1819 std::complex<float> alpha, \ 1820 const DeviceMemory<std::complex<float>> &x, int incx, \ 1821 const DeviceMemory<std::complex<float>> &y, int incy, \ 1822 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1823 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1824 std::complex<double> alpha, \ 1825 const DeviceMemory<std::complex<double>> &x, int incx, \ 1826 const DeviceMemory<std::complex<double>> &y, int incy, \ 1827 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1828 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1829 std::complex<float> alpha, \ 1830 const DeviceMemory<std::complex<float>> &ap, \ 1831 const DeviceMemory<std::complex<float>> &x, int incx, \ 1832 std::complex<float> beta, \ 1833 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1834 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1835 std::complex<double> alpha, \ 1836 const DeviceMemory<std::complex<double>> &ap, \ 1837 const DeviceMemory<std::complex<double>> &x, int incx, \ 1838 std::complex<double> beta, \ 1839 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1840 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1841 const DeviceMemory<std::complex<float>> &x, int incx, \ 1842 DeviceMemory<std::complex<float>> *ap) override; \ 1843 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1844 double alpha, const 
DeviceMemory<std::complex<double>> &x, \ 1845 int incx, DeviceMemory<std::complex<double>> *ap) override; \ 1846 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1847 std::complex<float> alpha, \ 1848 const DeviceMemory<std::complex<float>> &x, int incx, \ 1849 const DeviceMemory<std::complex<float>> &y, int incy, \ 1850 DeviceMemory<std::complex<float>> *ap) override; \ 1851 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1852 std::complex<double> alpha, \ 1853 const DeviceMemory<std::complex<double>> &x, int incx, \ 1854 const DeviceMemory<std::complex<double>> &y, int incy, \ 1855 DeviceMemory<std::complex<double>> *ap) override; \ 1856 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1857 float alpha, const DeviceMemory<float> &a, int lda, \ 1858 const DeviceMemory<float> &x, int incx, float beta, \ 1859 DeviceMemory<float> *y, int incy) override; \ 1860 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1861 double alpha, const DeviceMemory<double> &a, int lda, \ 1862 const DeviceMemory<double> &x, int incx, double beta, \ 1863 DeviceMemory<double> *y, int incy) override; \ 1864 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1865 float alpha, const DeviceMemory<float> &ap, \ 1866 const DeviceMemory<float> &x, int incx, float beta, \ 1867 DeviceMemory<float> *y, int incy) override; \ 1868 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1869 double alpha, const DeviceMemory<double> &ap, \ 1870 const DeviceMemory<double> &x, int incx, double beta, \ 1871 DeviceMemory<double> *y, int incy) override; \ 1872 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1873 const DeviceMemory<float> &x, int incx, \ 1874 DeviceMemory<float> *ap) override; \ 1875 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1876 double alpha, const DeviceMemory<double> &x, int incx, \ 1877 DeviceMemory<double> *ap) 
override; \ 1878 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1879 float alpha, const DeviceMemory<float> &x, int incx, \ 1880 const DeviceMemory<float> &y, int incy, \ 1881 DeviceMemory<float> *ap) override; \ 1882 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1883 double alpha, const DeviceMemory<double> &x, int incx, \ 1884 const DeviceMemory<double> &y, int incy, \ 1885 DeviceMemory<double> *ap) override; \ 1886 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1887 float alpha, const DeviceMemory<float> &a, int lda, \ 1888 const DeviceMemory<float> &x, int incx, float beta, \ 1889 DeviceMemory<float> *y, int incy) override; \ 1890 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1891 double alpha, const DeviceMemory<double> &a, int lda, \ 1892 const DeviceMemory<double> &x, int incx, double beta, \ 1893 DeviceMemory<double> *y, int incy) override; \ 1894 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1895 const DeviceMemory<float> &x, int incx, \ 1896 DeviceMemory<float> *a, int lda) override; \ 1897 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1898 double alpha, const DeviceMemory<double> &x, int incx, \ 1899 DeviceMemory<double> *a, int lda) override; \ 1900 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1901 float alpha, const DeviceMemory<float> &x, int incx, \ 1902 const DeviceMemory<float> &y, int incy, \ 1903 DeviceMemory<float> *a, int lda) override; \ 1904 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1905 double alpha, const DeviceMemory<double> &x, int incx, \ 1906 const DeviceMemory<double> &y, int incy, \ 1907 DeviceMemory<double> *a, int lda) override; \ 1908 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1909 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1910 uint64 k, const DeviceMemory<float> &a, int lda, \ 1911 DeviceMemory<float> *x, int incx) override; 
\ 1912 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1913 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1914 uint64 k, const DeviceMemory<double> &a, int lda, \ 1915 DeviceMemory<double> *x, int incx) override; \ 1916 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1917 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1918 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1919 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1920 override; \ 1921 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1922 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1923 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1924 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1925 override; \ 1926 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1927 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1928 uint64 k, const DeviceMemory<float> &a, int lda, \ 1929 DeviceMemory<float> *x, int incx) override; \ 1930 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1931 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1932 uint64 k, const DeviceMemory<double> &a, int lda, \ 1933 DeviceMemory<double> *x, int incx) override; \ 1934 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1935 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1936 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1937 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1938 override; \ 1939 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1940 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1941 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1942 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1943 override; \ 1944 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1945 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1946 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1947 int incx) override; \ 1948 bool DoBlasTpmv(Stream 
*stream, blas::UpperLower uplo, \ 1949 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1950 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1951 int incx) override; \ 1952 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1953 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1954 const DeviceMemory<std::complex<float>> &ap, \ 1955 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1956 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1957 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1958 const DeviceMemory<std::complex<double>> &ap, \ 1959 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1960 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1961 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1962 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1963 int incx) override; \ 1964 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1965 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1966 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1967 int incx) override; \ 1968 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1969 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1970 const DeviceMemory<std::complex<float>> &ap, \ 1971 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1972 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1973 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1974 const DeviceMemory<std::complex<double>> &ap, \ 1975 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1976 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1977 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1978 const DeviceMemory<float> &a, int lda, \ 1979 DeviceMemory<float> *x, int incx) override; \ 1980 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1981 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1982 const DeviceMemory<double> &a, int lda, \ 1983 DeviceMemory<double> *x, int 
incx) override; \ 1984 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1985 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1986 const DeviceMemory<std::complex<float>> &a, int lda, \ 1987 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1988 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1989 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1990 const DeviceMemory<std::complex<double>> &a, int lda, \ 1991 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1992 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1993 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1994 const DeviceMemory<float> &a, int lda, \ 1995 DeviceMemory<float> *x, int incx) override; \ 1996 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1997 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1998 const DeviceMemory<double> &a, int lda, \ 1999 DeviceMemory<double> *x, int incx) override; \ 2000 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 2001 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 2002 const DeviceMemory<std::complex<float>> &a, int lda, \ 2003 DeviceMemory<std::complex<float>> *x, int incx) override; \ 2004 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 2005 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 2006 const DeviceMemory<std::complex<double>> &a, int lda, \ 2007 DeviceMemory<std::complex<double>> *x, int incx) override; \ 2008 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 2009 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 2010 float alpha, const DeviceMemory<Eigen::half> &a, int lda, \ 2011 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 2012 DeviceMemory<Eigen::half> *c, int ldc) override; \ 2013 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 2014 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 2015 float alpha, const DeviceMemory<float> &a, int lda, \ 2016 const DeviceMemory<float> &b, int ldb, float 
beta, \ 2017 DeviceMemory<float> *c, int ldc) override; \ 2018 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 2019 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 2020 double alpha, const DeviceMemory<double> &a, int lda, \ 2021 const DeviceMemory<double> &b, int ldb, double beta, \ 2022 DeviceMemory<double> *c, int ldc) override; \ 2023 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 2024 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 2025 std::complex<float> alpha, \ 2026 const DeviceMemory<std::complex<float>> &a, int lda, \ 2027 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2028 std::complex<float> beta, \ 2029 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2030 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 2031 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 2032 std::complex<double> alpha, \ 2033 const DeviceMemory<std::complex<double>> &a, int lda, \ 2034 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2035 std::complex<double> beta, \ 2036 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2037 bool DoBlasGemmWithProfiling( \ 2038 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2039 uint64 m, uint64 n, uint64 k, float alpha, \ 2040 const DeviceMemory<Eigen::half> &a, int lda, \ 2041 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 2042 DeviceMemory<Eigen::half> *c, int ldc, \ 2043 blas::ProfileResult *output_profile_result) override; \ 2044 bool DoBlasGemmWithProfiling( \ 2045 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2046 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 2047 int lda, const DeviceMemory<float> &b, int ldb, float beta, \ 2048 DeviceMemory<float> *c, int ldc, \ 2049 blas::ProfileResult *output_profile_result) override; \ 2050 bool DoBlasGemmWithProfiling( \ 2051 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2052 uint64 m, uint64 n, uint64 k, double 
alpha, \ 2053 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 2054 int ldb, double beta, DeviceMemory<double> *c, int ldc, \ 2055 blas::ProfileResult *output_profile_result) override; \ 2056 bool DoBlasGemmWithProfiling( \ 2057 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2058 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2059 const DeviceMemory<std::complex<float>> &a, int lda, \ 2060 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2061 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 2062 blas::ProfileResult *output_profile_result) override; \ 2063 bool DoBlasGemmWithProfiling( \ 2064 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2065 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2066 const DeviceMemory<std::complex<double>> &a, int lda, \ 2067 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2068 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 2069 int ldc, blas::ProfileResult *output_profile_result) override; \ 2070 bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \ 2071 override; \ 2072 bool DoBlasGemmWithAlgorithm( \ 2073 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2074 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha, \ 2075 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, \ 2076 int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, \ 2077 int ldc, blas::ComputationType computation_type, \ 2078 blas::AlgorithmType algorithm, \ 2079 blas::ProfileResult *output_profile_result) override; \ 2080 bool DoBlasGemmWithAlgorithm( \ 2081 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2082 uint64 m, uint64 n, uint64 k, \ 2083 const HostOrDeviceScalar<Eigen::half> &alpha, \ 2084 const DeviceMemory<Eigen::half> &a, int lda, \ 2085 const DeviceMemory<Eigen::half> &b, int ldb, \ 2086 const 
HostOrDeviceScalar<Eigen::half> &beta, \ 2087 DeviceMemory<Eigen::half> *c, int ldc, \ 2088 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 2089 blas::ProfileResult *output_profile_result) override; \ 2090 bool DoBlasGemmWithAlgorithm( \ 2091 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2092 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha, \ 2093 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, \ 2094 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, \ 2095 int ldc, blas::ComputationType computation_type, \ 2096 blas::AlgorithmType algorithm, \ 2097 blas::ProfileResult *output_profile_result) override; \ 2098 bool DoBlasGemmWithAlgorithm( \ 2099 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2100 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha, \ 2101 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 2102 int ldb, const HostOrDeviceScalar<double> &beta, \ 2103 DeviceMemory<double> *c, int ldc, \ 2104 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 2105 blas::ProfileResult *output_profile_result) override; \ 2106 bool DoBlasGemmWithAlgorithm( \ 2107 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2108 uint64 m, uint64 n, uint64 k, \ 2109 const HostOrDeviceScalar<std::complex<float>> &alpha, \ 2110 const DeviceMemory<std::complex<float>> &a, int lda, \ 2111 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2112 const HostOrDeviceScalar<std::complex<float>> &beta, \ 2113 DeviceMemory<std::complex<float>> *c, int ldc, \ 2114 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 2115 blas::ProfileResult *output_profile_result) override; \ 2116 bool DoBlasGemmWithAlgorithm( \ 2117 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2118 uint64 m, uint64 n, uint64 k, \ 2119 const HostOrDeviceScalar<std::complex<double>> 
&alpha, \ 2120 const DeviceMemory<std::complex<double>> &a, int lda, \ 2121 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2122 const HostOrDeviceScalar<std::complex<double>> &beta, \ 2123 DeviceMemory<std::complex<double>> *c, int ldc, \ 2124 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 2125 blas::ProfileResult *output_profile_result) override; \ 2126 bool DoBlasGemmBatched( \ 2127 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2128 uint64 m, uint64 n, uint64 k, float alpha, \ 2129 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \ 2130 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \ 2131 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \ 2132 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2133 bool DoBlasGemmBatched( \ 2134 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2135 uint64 m, uint64 n, uint64 k, float alpha, \ 2136 const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \ 2137 const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \ 2138 const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \ 2139 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2140 bool DoBlasGemmBatched( \ 2141 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2142 uint64 m, uint64 n, uint64 k, double alpha, \ 2143 const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \ 2144 const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \ 2145 const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \ 2146 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2147 bool DoBlasGemmBatched( \ 2148 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2149 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2150 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \ 2151 const 
port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \ 2152 std::complex<float> beta, \ 2153 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \ 2154 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2155 bool DoBlasGemmBatched( \ 2156 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2157 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2158 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, \ 2159 int lda, \ 2160 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \ 2161 int ldb, std::complex<double> beta, \ 2162 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ 2163 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2164 bool DoBlasGemmStridedBatched( \ 2165 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2166 uint64 m, uint64 n, uint64 k, float alpha, \ 2167 const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ 2168 const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ 2169 DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ 2170 bool DoBlasGemmStridedBatched( \ 2171 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2172 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 2173 int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ 2174 int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ 2175 int64 stride_c, int batch_count); \ 2176 bool DoBlasGemmStridedBatched( \ 2177 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2178 uint64 m, uint64 n, uint64 k, double alpha, \ 2179 const DeviceMemory<double> &a, int lda, int64 stride_a, \ 2180 const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ 2181 DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ 2182 bool DoBlasGemmStridedBatched( \ 2183 Stream *stream, blas::Transpose transa, 
blas::Transpose transb, \ 2184 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2185 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ 2186 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ 2187 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 2188 int64 stride_c, int batch_count); \ 2189 bool DoBlasGemmStridedBatched( \ 2190 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2191 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2192 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \ 2193 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ 2194 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 2195 int ldc, int64 stride_c, int batch_count); \ 2196 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2197 uint64 m, uint64 n, std::complex<float> alpha, \ 2198 const DeviceMemory<std::complex<float>> &a, int lda, \ 2199 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2200 std::complex<float> beta, \ 2201 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2202 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2203 uint64 m, uint64 n, std::complex<double> alpha, \ 2204 const DeviceMemory<std::complex<double>> &a, int lda, \ 2205 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2206 std::complex<double> beta, \ 2207 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2208 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2209 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2210 const DeviceMemory<std::complex<float>> &a, int lda, \ 2211 float beta, DeviceMemory<std::complex<float>> *c, int ldc) \ 2212 override; \ 2213 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2214 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2215 const DeviceMemory<std::complex<double>> &a, int lda, \ 2216 double beta, 
DeviceMemory<std::complex<double>> *c, int ldc) \ 2217 override; \ 2218 bool DoBlasHer2k( \ 2219 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2220 uint64 k, std::complex<float> alpha, \ 2221 const DeviceMemory<std::complex<float>> &a, int lda, \ 2222 const DeviceMemory<std::complex<float>> &b, int ldb, float beta, \ 2223 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2224 bool DoBlasHer2k( \ 2225 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2226 uint64 k, std::complex<double> alpha, \ 2227 const DeviceMemory<std::complex<double>> &a, int lda, \ 2228 const DeviceMemory<std::complex<double>> &b, int ldb, double beta, \ 2229 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2230 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2231 uint64 m, uint64 n, float alpha, \ 2232 const DeviceMemory<float> &a, int lda, \ 2233 const DeviceMemory<float> &b, int ldb, float beta, \ 2234 DeviceMemory<float> *c, int ldc) override; \ 2235 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2236 uint64 m, uint64 n, double alpha, \ 2237 const DeviceMemory<double> &a, int lda, \ 2238 const DeviceMemory<double> &b, int ldb, double beta, \ 2239 DeviceMemory<double> *c, int ldc) override; \ 2240 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2241 uint64 m, uint64 n, std::complex<float> alpha, \ 2242 const DeviceMemory<std::complex<float>> &a, int lda, \ 2243 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2244 std::complex<float> beta, \ 2245 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2246 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2247 uint64 m, uint64 n, std::complex<double> alpha, \ 2248 const DeviceMemory<std::complex<double>> &a, int lda, \ 2249 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2250 std::complex<double> beta, \ 2251 DeviceMemory<std::complex<double>> *c, int 
ldc) override; \ 2252 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2253 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2254 const DeviceMemory<float> &a, int lda, float beta, \ 2255 DeviceMemory<float> *c, int ldc) override; \ 2256 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2257 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2258 const DeviceMemory<double> &a, int lda, double beta, \ 2259 DeviceMemory<double> *c, int ldc) override; \ 2260 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2261 blas::Transpose trans, uint64 n, uint64 k, \ 2262 std::complex<float> alpha, \ 2263 const DeviceMemory<std::complex<float>> &a, int lda, \ 2264 std::complex<float> beta, \ 2265 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2266 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2267 blas::Transpose trans, uint64 n, uint64 k, \ 2268 std::complex<double> alpha, \ 2269 const DeviceMemory<std::complex<double>> &a, int lda, \ 2270 std::complex<double> beta, \ 2271 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2272 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2273 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2274 const DeviceMemory<float> &a, int lda, \ 2275 const DeviceMemory<float> &b, int ldb, float beta, \ 2276 DeviceMemory<float> *c, int ldc) override; \ 2277 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2278 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2279 const DeviceMemory<double> &a, int lda, \ 2280 const DeviceMemory<double> &b, int ldb, double beta, \ 2281 DeviceMemory<double> *c, int ldc) override; \ 2282 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2283 blas::Transpose trans, uint64 n, uint64 k, \ 2284 std::complex<float> alpha, \ 2285 const DeviceMemory<std::complex<float>> &a, int lda, \ 2286 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2287 std::complex<float> beta, \ 2288 
DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2289 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2290 blas::Transpose trans, uint64 n, uint64 k, \ 2291 std::complex<double> alpha, \ 2292 const DeviceMemory<std::complex<double>> &a, int lda, \ 2293 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2294 std::complex<double> beta, \ 2295 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2296 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2297 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2298 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2299 int lda, DeviceMemory<float> *b, int ldb) override; \ 2300 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2301 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2302 uint64 n, double alpha, const DeviceMemory<double> &a, \ 2303 int lda, DeviceMemory<double> *b, int ldb) override; \ 2304 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2305 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2306 uint64 n, std::complex<float> alpha, \ 2307 const DeviceMemory<std::complex<float>> &a, int lda, \ 2308 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2309 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2310 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2311 uint64 n, std::complex<double> alpha, \ 2312 const DeviceMemory<std::complex<double>> &a, int lda, \ 2313 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2314 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2315 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2316 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2317 int lda, DeviceMemory<float> *b, int ldb) override; \ 2318 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2319 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2320 uint64 n, double 
alpha, const DeviceMemory<double> &a, \ 2321 int lda, DeviceMemory<double> *b, int ldb) override; \ 2322 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2323 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2324 uint64 n, std::complex<float> alpha, \ 2325 const DeviceMemory<std::complex<float>> &a, int lda, \ 2326 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2327 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2328 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2329 uint64 n, std::complex<double> alpha, \ 2330 const DeviceMemory<std::complex<double>> &a, int lda, \ 2331 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2332 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> \ 2333 CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) override; \ 2334 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> \ 2335 GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, \ 2336 size_t max_workspace_size, \ 2337 int max_algorithm_count) override; \ 2338 bool DoBlasLtMatmul( \ 2339 Stream *stream, const blas::IBlasLtMatmulPlan *plan, \ 2340 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, \ 2341 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, \ 2342 DeviceMemoryBase c, ScratchAllocator *scratch_allocator, \ 2343 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, \ 2344 blas::ProfileResult *output_profile_result) override; \ 2345 port::Status GetVersion(std::string *version) override; 2346 2347 } // namespace blas 2348 } // namespace stream_executor 2349 2350 #endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 2351