/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Exposes the family of BLAS routines as pre-canned high performance calls for
// use in conjunction with the StreamExecutor abstraction.
//
// Note that this interface is optionally supported by platforms; see
// StreamExecutor::SupportsBlas() for details.
//
// This abstraction makes it simple to entrain BLAS operations on GPU data into
// a Stream -- users typically will not use this API directly, but will use the
// Stream builder methods to entrain these operations "under the hood". For
// example:
//
//  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
//  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
//  // ... populate x and y ...
//  Stream stream{stream_exec};
//  stream
//    .Init()
//    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
//  SE_CHECK_OK(stream.BlockHostUntilDone());
//
// By using stream operations in this manner the user can easily intermix custom
// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS
// routines.
39 40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 42 43 #include <complex> 44 #include <vector> 45 46 #include "tensorflow/stream_executor/dnn.h" // For DataType, ToDataType 47 #include "tensorflow/stream_executor/lib/array_slice.h" 48 #include "tensorflow/stream_executor/lib/statusor.h" 49 #include "tensorflow/stream_executor/platform/port.h" 50 51 namespace Eigen { 52 struct half; 53 } // namespace Eigen 54 55 namespace stream_executor { 56 57 class Stream; 58 class ScratchAllocator; 59 60 template <typename ElemT> 61 class DeviceMemory; 62 63 template <typename ElemT> 64 class HostOrDeviceScalar; 65 66 namespace blas { 67 68 // Specifies whether the input matrix will be transposed or 69 // transposed+conjugated before any BLAS operations. 70 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; 71 72 // Returns a name for t. 73 std::string TransposeString(Transpose t); 74 75 // Specifies whether the upper or lower triangular part of a 76 // symmetric/Hermitian matrix is used. 77 enum class UpperLower { kUpper, kLower }; 78 79 // Returns a name for ul. 80 std::string UpperLowerString(UpperLower ul); 81 82 // Specifies whether a matrix is unit triangular. 83 enum class Diagonal { kUnit, kNonUnit }; 84 85 // Returns a name for d. 86 std::string DiagonalString(Diagonal d); 87 88 // Specifies whether a Hermitian matrix appears on the left or right in 89 // operation. 90 enum class Side { kLeft, kRight }; 91 92 // Returns a name for s. 93 std::string SideString(Side s); 94 95 // Type with which intermediate computations of a blas routine are performed. 96 // 97 // Some blas calls can perform computations with a type that's different than 98 // the type of their inputs/outputs. This lets you e.g. multiply two matrices 99 // of int8s using float32s to store the matmul's intermediate values. 
100 enum class ComputationType { 101 kF16, // 16-bit floating-point 102 kF32, // 32-bit floating-point 103 kF64, // 64-bit floating-point 104 kI32, // 32-bit integer 105 kComplexF32, // Complex number comprised of two f32s. 106 kComplexF64, // Complex number comprised of two f64s. 107 // The below values are only supported for BlasLt routines (both real and 108 // complex). They use float32 for accumulation but round the input mantissas 109 // to a smaller number of bits. 110 kTF32AsF32, // 32-bit floating-point with reduced (>=10-bit) mantissa 111 kBF16AsF32, // 32-bit floating-point with reduced (7-bit) mantissa 112 }; 113 114 enum class Epilogue { 115 kDefault = 1, // No special postprocessing 116 kReLU = 2, // Apply ReLU func point-wise to the results 117 kBias = 4, // Add broadcasted bias vector to the results 118 kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform 119 }; 120 121 // Converts a ComputationType to a string. 122 std::string ComputationTypeString(ComputationType ty); 123 124 template <typename T> 125 struct ToComputationType; 126 template <> 127 struct ToComputationType<float> { 128 static constexpr ComputationType value = ComputationType::kF32; 129 }; 130 template <> 131 struct ToComputationType<double> { 132 static constexpr ComputationType value = ComputationType::kF64; 133 }; 134 template <> 135 struct ToComputationType<Eigen::half> { 136 static constexpr ComputationType value = ComputationType::kF16; 137 }; 138 template <> 139 struct ToComputationType<Eigen::bfloat16> { 140 static constexpr ComputationType value = ComputationType::kBF16AsF32; 141 }; 142 template <> 143 struct ToComputationType<tensorflow::int32> { 144 static constexpr ComputationType value = ComputationType::kI32; 145 }; 146 template <> 147 struct ToComputationType<std::complex<float>> { 148 static constexpr ComputationType value = ComputationType::kComplexF32; 149 }; 150 template <> 151 struct ToComputationType<std::complex<double>> { 152 static constexpr 
ComputationType value = ComputationType::kComplexF64; 153 }; 154 155 std::ostream &operator<<(std::ostream &os, ComputationType ty); 156 157 using dnn::DataType; 158 using dnn::ToDataType; 159 160 // Describes the type of pointers for the scaling factors alpha and beta in 161 // blaslt routines. 162 enum class PointerMode { 163 kHost, 164 kDevice, 165 }; 166 167 // Converts a ComputationType to a string. 168 std::string DataTypeString(DataType ty); 169 170 std::ostream &operator<<(std::ostream &os, DataType ty); 171 172 // Opaque identifier for an "algorithm" used by a blas routine. This functions 173 // as a hint to the blas library. 174 typedef int64 AlgorithmType; 175 constexpr AlgorithmType kDefaultAlgorithm = -1; 176 constexpr AlgorithmType kDefaultBlasGemm = -2; 177 constexpr AlgorithmType kDefaultBlasGemv = -3; 178 constexpr AlgorithmType kNoAlgorithm = -4; 179 180 // blas uses -1 to represent the default algorithm. This happens to match up 181 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast 182 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert 183 // to ensure that this assumption does not break. 184 // If another blas implementation uses a different value for the default 185 // algorithm, then it needs to convert kDefaultGemmAlgo to that value 186 // (e.g. via a function called ToWhateverGemmAlgo). 187 constexpr AlgorithmType kDefaultGemmAlgo = -1; 188 189 // Describes the result of a performance experiment, usually timing the speed of 190 // a particular AlgorithmType. 191 // 192 // If the call we were benchmarking failed (a common occurrence; not all 193 // algorithms are valid for all calls), is_valid() will be false. 
194 class ProfileResult { 195 public: 196 bool is_valid() const { return is_valid_; } 197 void set_is_valid(bool val) { is_valid_ = val; } 198 AlgorithmType algorithm() const { return algorithm_; } 199 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 200 float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } 201 void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } 202 203 private: 204 bool is_valid_ = false; 205 AlgorithmType algorithm_ = kDefaultAlgorithm; 206 float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); 207 }; 208 209 class AlgorithmConfig { 210 public: 211 AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} 212 explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} 213 AlgorithmType algorithm() const { return algorithm_; } 214 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 215 bool operator==(const AlgorithmConfig &other) const { 216 return this->algorithm_ == other.algorithm_; 217 } 218 bool operator!=(const AlgorithmConfig &other) const { 219 return !(*this == other); 220 } 221 std::string ToString() const; 222 223 private: 224 AlgorithmType algorithm_; 225 }; 226 227 struct IBlasLtMatmulPlan { 228 // Returns the data type of the A and B (input) matrices. 229 virtual DataType ab_type() const = 0; 230 // Returns the data type of the C (input/output) matrix. 231 virtual DataType c_type() const = 0; 232 virtual ~IBlasLtMatmulPlan() {} 233 }; 234 235 struct IBlasLtMatmulAlgorithm { 236 virtual ~IBlasLtMatmulAlgorithm() {} 237 // Returns the index of the algorithm within the list returned by 238 // GetBlasLtMatmulAlgorithms. 239 virtual AlgorithmType index() const = 0; 240 // Returns the workspace size required by the algorithm in bytes. 241 virtual size_t workspace_size() const = 0; 242 }; 243 244 // Parameters for the CreateBlasLtMatmulPlan method. 
245 struct BlasLtMatmulPlanParams { 246 DataType ab_type; 247 DataType c_type; 248 ComputationType computation_type; 249 PointerMode pointer_mode; 250 Epilogue epilogue; 251 Transpose transa; 252 Transpose transb; 253 uint64 m; 254 uint64 n; 255 uint64 k; 256 int64 lda; 257 int64 ldb; 258 int64 ldc; 259 int batch_count = 1; 260 int64 stride_a = 0; 261 int64 stride_b = 0; 262 int64 stride_c = 0; 263 }; 264 265 // BLAS support interface -- this can be derived from a GPU executor when the 266 // underlying platform has an BLAS library implementation available. See 267 // StreamExecutor::AsBlas(). 268 // 269 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in 270 // the system. Any operation that a user attempts to perform by enqueueing BLAS 271 // operations on a thread not-associated with the CUDA-context has unknown 272 // behavior at the current time; see b/13176597 273 class BlasSupport { 274 public: 275 virtual ~BlasSupport() {} 276 277 // Computes the sum of magnitudes of the vector elements. 278 // result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)|+ ... + |Re x(n)| 279 // + |Im x(n)|. 280 // Note that Im x(i) = 0 for real types float/double. 281 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 282 const DeviceMemory<float> &x, int incx, 283 DeviceMemory<float> *result) = 0; 284 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 285 const DeviceMemory<double> &x, int incx, 286 DeviceMemory<double> *result) = 0; 287 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 288 const DeviceMemory<std::complex<float>> &x, int incx, 289 DeviceMemory<float> *result) = 0; 290 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 291 const DeviceMemory<std::complex<double>> &x, int incx, 292 DeviceMemory<double> *result) = 0; 293 294 // Performs a BLAS y <- ax+y operation. 
295 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, 296 const DeviceMemory<float> &x, int incx, 297 DeviceMemory<float> *y, int incy) = 0; 298 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, 299 const DeviceMemory<double> &x, int incx, 300 DeviceMemory<double> *y, int incy) = 0; 301 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, 302 std::complex<float> alpha, 303 const DeviceMemory<std::complex<float>> &x, int incx, 304 DeviceMemory<std::complex<float>> *y, int incy) = 0; 305 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, 306 std::complex<double> alpha, 307 const DeviceMemory<std::complex<double>> &x, int incx, 308 DeviceMemory<std::complex<double>> *y, int incy) = 0; 309 310 // Copies vector to another vector: y <- x. 311 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 312 const DeviceMemory<float> &x, int incx, 313 DeviceMemory<float> *y, int incy) = 0; 314 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 315 const DeviceMemory<double> &x, int incx, 316 DeviceMemory<double> *y, int incy) = 0; 317 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 318 const DeviceMemory<std::complex<float>> &x, int incx, 319 DeviceMemory<std::complex<float>> *y, int incy) = 0; 320 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 321 const DeviceMemory<std::complex<double>> &x, int incx, 322 DeviceMemory<std::complex<double>> *y, int incy) = 0; 323 324 // Performs a BLAS dot product result <- x . y. 325 virtual bool DoBlasDot(Stream *stream, uint64 elem_count, 326 const DeviceMemory<float> &x, int incx, 327 const DeviceMemory<float> &y, int incy, 328 DeviceMemory<float> *result) = 0; 329 virtual bool DoBlasDot(Stream *stream, uint64 elem_count, 330 const DeviceMemory<double> &x, int incx, 331 const DeviceMemory<double> &y, int incy, 332 DeviceMemory<double> *result) = 0; 333 334 // Performs a BLAS dot product result <- conj(x) . y for complex types. 
335 virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, 336 const DeviceMemory<std::complex<float>> &x, int incx, 337 const DeviceMemory<std::complex<float>> &y, int incy, 338 DeviceMemory<std::complex<float>> *result) = 0; 339 virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, 340 const DeviceMemory<std::complex<double>> &x, int incx, 341 const DeviceMemory<std::complex<double>> &y, int incy, 342 DeviceMemory<std::complex<double>> *result) = 0; 343 344 // Performs a BLAS dot product result <- x . y for complex types. Note that 345 // x is unconjugated in this routine. 346 virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, 347 const DeviceMemory<std::complex<float>> &x, int incx, 348 const DeviceMemory<std::complex<float>> &y, int incy, 349 DeviceMemory<std::complex<float>> *result) = 0; 350 virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, 351 const DeviceMemory<std::complex<double>> &x, int incx, 352 const DeviceMemory<std::complex<double>> &y, int incy, 353 DeviceMemory<std::complex<double>> *result) = 0; 354 355 // Computes the Euclidean norm of a vector: result <- ||x||. 356 // See the following link for more information of Euclidean norm: 357 // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm 358 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 359 const DeviceMemory<float> &x, int incx, 360 DeviceMemory<float> *result) = 0; 361 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 362 const DeviceMemory<double> &x, int incx, 363 DeviceMemory<double> *result) = 0; 364 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 365 const DeviceMemory<std::complex<float>> &x, int incx, 366 DeviceMemory<float> *result) = 0; 367 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 368 const DeviceMemory<std::complex<double>> &x, int incx, 369 DeviceMemory<double> *result) = 0; 370 371 // Performs rotation of points in the plane: 372 // x(i) = c*x(i) + s*y(i) 373 // y(i) = c*y(i) - s*x(i). 
374 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 375 DeviceMemory<float> *x, int incx, 376 DeviceMemory<float> *y, int incy, float c, 377 float s) = 0; 378 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 379 DeviceMemory<double> *x, int incx, 380 DeviceMemory<double> *y, int incy, double c, 381 double s) = 0; 382 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 383 DeviceMemory<std::complex<float>> *x, int incx, 384 DeviceMemory<std::complex<float>> *y, int incy, 385 float c, float s) = 0; 386 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 387 DeviceMemory<std::complex<double>> *x, int incx, 388 DeviceMemory<std::complex<double>> *y, int incy, 389 double c, double s) = 0; 390 391 // Computes the parameters for a Givens rotation. 392 // Given the Cartesian coordinates (a, b) of a point, these routines return 393 // the parameters c, s, r, and z associated with the Givens rotation. The 394 // parameters c and s define a unitary matrix such that: 395 // 396 // | c s |.| a | = | r | 397 // | -s c | | b | | 0 | 398 // 399 // The parameter z is defined such that if |a| > |b|, z is s; otherwise if 400 // c is not 0 z is 1/c; otherwise z is 1. 401 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, 402 DeviceMemory<float> *b, DeviceMemory<float> *c, 403 DeviceMemory<float> *s) = 0; 404 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, 405 DeviceMemory<double> *b, DeviceMemory<double> *c, 406 DeviceMemory<double> *s) = 0; 407 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, 408 DeviceMemory<std::complex<float>> *b, 409 DeviceMemory<float> *c, 410 DeviceMemory<std::complex<float>> *s) = 0; 411 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, 412 DeviceMemory<std::complex<double>> *b, 413 DeviceMemory<double> *c, 414 DeviceMemory<std::complex<double>> *s) = 0; 415 416 // Performs modified Givens rotation of points in the plane. 
417 // Given two vectors x and y, each vector element of these vectors is replaced 418 // as follows: 419 // 420 // | x(i) | = H | x(i) | 421 // | y(i) | | y(i) | 422 // 423 // for i=1 to n, where H is a modified Givens transformation matrix whose 424 // values are stored in the param[1] through param[4] array. 425 // For more information please Google this routine. 426 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 427 DeviceMemory<float> *x, int incx, 428 DeviceMemory<float> *y, int incy, 429 const DeviceMemory<float> ¶m) = 0; 430 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 431 DeviceMemory<double> *x, int incx, 432 DeviceMemory<double> *y, int incy, 433 const DeviceMemory<double> ¶m) = 0; 434 435 // Computes the parameters for a modified Givens rotation. 436 // Given Cartesian coordinates (x1, y1) of an input vector, these routines 437 // compute the components of a modified Givens transformation matrix H that 438 // zeros the y-component of the resulting vector: 439 // 440 // | x1 | = H | x1 * sqrt(d1) | 441 // | 0 | | y1 * sqrt(d1) | 442 // 443 // For more information please Google this routine. 444 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, 445 DeviceMemory<float> *d2, DeviceMemory<float> *x1, 446 const DeviceMemory<float> &y1, 447 DeviceMemory<float> *param) = 0; 448 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, 449 DeviceMemory<double> *d2, DeviceMemory<double> *x1, 450 const DeviceMemory<double> &y1, 451 DeviceMemory<double> *param) = 0; 452 453 // Computes the product of a vector by a scalar: x <- a*x. 
454 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, 455 DeviceMemory<float> *x, int incx) = 0; 456 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, 457 DeviceMemory<double> *x, int incx) = 0; 458 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, 459 DeviceMemory<std::complex<float>> *x, int incx) = 0; 460 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, 461 DeviceMemory<std::complex<double>> *x, int incx) = 0; 462 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, 463 std::complex<float> alpha, 464 DeviceMemory<std::complex<float>> *x, int incx) = 0; 465 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, 466 std::complex<double> alpha, 467 DeviceMemory<std::complex<double>> *x, int incx) = 0; 468 469 // Swaps a vector with another vector. 470 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 471 DeviceMemory<float> *x, int incx, 472 DeviceMemory<float> *y, int incy) = 0; 473 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 474 DeviceMemory<double> *x, int incx, 475 DeviceMemory<double> *y, int incy) = 0; 476 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 477 DeviceMemory<std::complex<float>> *x, int incx, 478 DeviceMemory<std::complex<float>> *y, int incy) = 0; 479 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 480 DeviceMemory<std::complex<double>> *x, int incx, 481 DeviceMemory<std::complex<double>> *y, int incy) = 0; 482 483 // Finds the index of the element with maximum absolute value. 
484 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 485 const DeviceMemory<float> &x, int incx, 486 DeviceMemory<int> *result) = 0; 487 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 488 const DeviceMemory<double> &x, int incx, 489 DeviceMemory<int> *result) = 0; 490 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 491 const DeviceMemory<std::complex<float>> &x, int incx, 492 DeviceMemory<int> *result) = 0; 493 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 494 const DeviceMemory<std::complex<double>> &x, 495 int incx, DeviceMemory<int> *result) = 0; 496 497 // Finds the index of the element with minimum absolute value. 498 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 499 const DeviceMemory<float> &x, int incx, 500 DeviceMemory<int> *result) = 0; 501 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 502 const DeviceMemory<double> &x, int incx, 503 DeviceMemory<int> *result) = 0; 504 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 505 const DeviceMemory<std::complex<float>> &x, int incx, 506 DeviceMemory<int> *result) = 0; 507 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 508 const DeviceMemory<std::complex<double>> &x, 509 int incx, DeviceMemory<int> *result) = 0; 510 511 // Computes a matrix-vector product using a general band matrix: 512 // 513 // y <- alpha * a * x + beta * y, 514 // or 515 // y <- alpha * a' * x + beta * y, 516 // or 517 // y <- alpha * conj(a') * x + beta * y, 518 // 519 // alpha and beta are scalars; a is an m-by-n general band matrix, with kl 520 // sub-diagonals and ku super-diagonals; x is a vector with 521 // n(trans==kNoTranspose)/m(otherwise) elements; 522 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
523 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 524 uint64 n, uint64 kl, uint64 ku, float alpha, 525 const DeviceMemory<float> &a, int lda, 526 const DeviceMemory<float> &x, int incx, float beta, 527 DeviceMemory<float> *y, int incy) = 0; 528 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 529 uint64 n, uint64 kl, uint64 ku, double alpha, 530 const DeviceMemory<double> &a, int lda, 531 const DeviceMemory<double> &x, int incx, double beta, 532 DeviceMemory<double> *y, int incy) = 0; 533 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 534 uint64 n, uint64 kl, uint64 ku, 535 std::complex<float> alpha, 536 const DeviceMemory<std::complex<float>> &a, int lda, 537 const DeviceMemory<std::complex<float>> &x, int incx, 538 std::complex<float> beta, 539 DeviceMemory<std::complex<float>> *y, int incy) = 0; 540 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 541 uint64 n, uint64 kl, uint64 ku, 542 std::complex<double> alpha, 543 const DeviceMemory<std::complex<double>> &a, int lda, 544 const DeviceMemory<std::complex<double>> &x, int incx, 545 std::complex<double> beta, 546 DeviceMemory<std::complex<double>> *y, int incy) = 0; 547 548 // Computes a matrix-vector product using a general matrix. 549 // 550 // y <- alpha * a * x + beta * y, 551 // or 552 // y <- alpha * a' * x + beta * y, 553 // or 554 // y <- alpha * conj(a') * x + beta * y, 555 // 556 // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector 557 // with n(trans==kNoTranspose)/m(otherwise) elements; 558 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
559 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 560 uint64 n, float alpha, const DeviceMemory<float> &a, 561 int lda, const DeviceMemory<float> &x, int incx, 562 float beta, DeviceMemory<float> *y, int incy) = 0; 563 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 564 uint64 n, double alpha, const DeviceMemory<double> &a, 565 int lda, const DeviceMemory<double> &x, int incx, 566 double beta, DeviceMemory<double> *y, int incy) = 0; 567 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 568 uint64 n, std::complex<float> alpha, 569 const DeviceMemory<std::complex<float>> &a, int lda, 570 const DeviceMemory<std::complex<float>> &x, int incx, 571 std::complex<float> beta, 572 DeviceMemory<std::complex<float>> *y, int incy) = 0; 573 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 574 uint64 n, std::complex<double> alpha, 575 const DeviceMemory<std::complex<double>> &a, int lda, 576 const DeviceMemory<std::complex<double>> &x, int incx, 577 std::complex<double> beta, 578 DeviceMemory<std::complex<double>> *y, int incy) = 0; 579 580 virtual bool DoBlasGemvWithProfiling( 581 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, 582 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, 583 int incx, float beta, DeviceMemory<float> *y, int incy, 584 ProfileResult *output_profile_result) = 0; 585 virtual bool DoBlasGemvWithProfiling( 586 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, 587 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, 588 int incx, double beta, DeviceMemory<double> *y, int incy, 589 ProfileResult *output_profile_result) = 0; 590 virtual bool DoBlasGemvWithProfiling( 591 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, 592 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, 593 int lda, const DeviceMemory<std::complex<float>> &x, int incx, 594 
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, 595 ProfileResult *output_profile_result) = 0; 596 virtual bool DoBlasGemvWithProfiling( 597 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, 598 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, 599 int lda, const DeviceMemory<std::complex<double>> &x, int incx, 600 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, 601 int incy, ProfileResult *output_profile_result) = 0; 602 603 // Performs a rank-1 update of a general matrix. 604 // 605 // a <- alpha * x * y' + a, 606 // 607 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 608 // an m-by-n general matrix. 609 virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, 610 const DeviceMemory<float> &x, int incx, 611 const DeviceMemory<float> &y, int incy, 612 DeviceMemory<float> *a, int lda) = 0; 613 virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, 614 const DeviceMemory<double> &x, int incx, 615 const DeviceMemory<double> &y, int incy, 616 DeviceMemory<double> *a, int lda) = 0; 617 618 // Performs a rank-1 update (conjugated) of a general matrix. 619 // 620 // a <- alpha * x * conj(y') + a, 621 // 622 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 623 // an m-by-n general matrix. 624 virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, 625 std::complex<float> alpha, 626 const DeviceMemory<std::complex<float>> &x, int incx, 627 const DeviceMemory<std::complex<float>> &y, int incy, 628 DeviceMemory<std::complex<float>> *a, int lda) = 0; 629 virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, 630 std::complex<double> alpha, 631 const DeviceMemory<std::complex<double>> &x, int incx, 632 const DeviceMemory<std::complex<double>> &y, int incy, 633 DeviceMemory<std::complex<double>> *a, int lda) = 0; 634 635 // Performs a rank-1 update (unconjugated) of a general matrix. 
636 // 637 // a <- alpha * x * y' + a, 638 // 639 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 640 // an m-by-n general matrix. 641 virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, 642 std::complex<float> alpha, 643 const DeviceMemory<std::complex<float>> &x, int incx, 644 const DeviceMemory<std::complex<float>> &y, int incy, 645 DeviceMemory<std::complex<float>> *a, int lda) = 0; 646 virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, 647 std::complex<double> alpha, 648 const DeviceMemory<std::complex<double>> &x, int incx, 649 const DeviceMemory<std::complex<double>> &y, int incy, 650 DeviceMemory<std::complex<double>> *a, int lda) = 0; 651 652 // Computes a matrix-vector product using a Hermitian band matrix. 653 // 654 // y <- alpha * a * x + beta * y, 655 // 656 // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k 657 // super-diagonals; x and y are n-element vectors. 658 virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 659 uint64 k, std::complex<float> alpha, 660 const DeviceMemory<std::complex<float>> &a, int lda, 661 const DeviceMemory<std::complex<float>> &x, int incx, 662 std::complex<float> beta, 663 DeviceMemory<std::complex<float>> *y, int incy) = 0; 664 virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 665 uint64 k, std::complex<double> alpha, 666 const DeviceMemory<std::complex<double>> &a, int lda, 667 const DeviceMemory<std::complex<double>> &x, int incx, 668 std::complex<double> beta, 669 DeviceMemory<std::complex<double>> *y, int incy) = 0; 670 671 // Computes a matrix-vector product using a Hermitian matrix. 672 // 673 // y <- alpha * a * x + beta * y, 674 // 675 // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are 676 // n-element vectors. 
677 virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, 678 std::complex<float> alpha, 679 const DeviceMemory<std::complex<float>> &a, int lda, 680 const DeviceMemory<std::complex<float>> &x, int incx, 681 std::complex<float> beta, 682 DeviceMemory<std::complex<float>> *y, int incy) = 0; 683 virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, 684 std::complex<double> alpha, 685 const DeviceMemory<std::complex<double>> &a, int lda, 686 const DeviceMemory<std::complex<double>> &x, int incx, 687 std::complex<double> beta, 688 DeviceMemory<std::complex<double>> *y, int incy) = 0; 689 690 // Performs a rank-1 update of a Hermitian matrix. 691 // 692 // a <- alpha * x * conj(x') + a, 693 // 694 // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian 695 // matrix. 696 virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, 697 float alpha, 698 const DeviceMemory<std::complex<float>> &x, int incx, 699 DeviceMemory<std::complex<float>> *a, int lda) = 0; 700 virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, 701 double alpha, 702 const DeviceMemory<std::complex<double>> &x, int incx, 703 DeviceMemory<std::complex<double>> *a, int lda) = 0; 704 705 // Performs a rank-2 update of a Hermitian matrix. 706 // 707 // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a, 708 // 709 // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian 710 // matrix. 
711 virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 712 std::complex<float> alpha, 713 const DeviceMemory<std::complex<float>> &x, int incx, 714 const DeviceMemory<std::complex<float>> &y, int incy, 715 DeviceMemory<std::complex<float>> *a, int lda) = 0; 716 virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 717 std::complex<double> alpha, 718 const DeviceMemory<std::complex<double>> &x, int incx, 719 const DeviceMemory<std::complex<double>> &y, int incy, 720 DeviceMemory<std::complex<double>> *a, int lda) = 0; 721 722 // Computes a matrix-vector product using a Hermitian packed matrix. 723 // 724 // y <- alpha * a * x + beta * y, 725 // 726 // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in 727 // packed form; x and y are n-element vectors. 728 virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 729 std::complex<float> alpha, 730 const DeviceMemory<std::complex<float>> &ap, 731 const DeviceMemory<std::complex<float>> &x, int incx, 732 std::complex<float> beta, 733 DeviceMemory<std::complex<float>> *y, int incy) = 0; 734 virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 735 std::complex<double> alpha, 736 const DeviceMemory<std::complex<double>> &ap, 737 const DeviceMemory<std::complex<double>> &x, int incx, 738 std::complex<double> beta, 739 DeviceMemory<std::complex<double>> *y, int incy) = 0; 740 741 // Performs a rank-1 update of a Hermitian packed matrix. 742 // 743 // a <- alpha * x * conj(x') + a, 744 // 745 // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian 746 // matrix, supplied in packed form. 
virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *ap) = 0;
virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *ap) = 0;

// Performs a rank-2 update of a Hermitian packed matrix.
//
//     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
// matrix, supplied in packed form (the `ap` argument).
virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        const DeviceMemory<std::complex<float>> &y, int incy,
                        DeviceMemory<std::complex<float>> *ap) = 0;
virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        const DeviceMemory<std::complex<double>> &y, int incy,
                        DeviceMemory<std::complex<double>> *ap) = 0;

// Computes a matrix-vector product using a symmetric band matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
// super-diagonals; x and y are n-element vectors.
virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        uint64 k, float alpha, const DeviceMemory<float> &a,
                        int lda, const DeviceMemory<float> &x, int incx,
                        float beta, DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        uint64 k, double alpha, const DeviceMemory<double> &a,
                        int lda, const DeviceMemory<double> &x, int incx,
                        double beta, DeviceMemory<double> *y, int incy) = 0;

// Computes a matrix-vector product using a symmetric packed matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
// packed form (the `ap` argument); x and y are n-element vectors.
virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &ap,
                        const DeviceMemory<float> &x, int incx, float beta,
                        DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &ap,
                        const DeviceMemory<double> &x, int incx, double beta,
                        DeviceMemory<double> *y, int incy) = 0;

// Performs a rank-1 update of a symmetric packed matrix.
//
//     a <- alpha * x * x' + a,
//
// alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
// matrix, supplied in packed form (the `ap` argument).
virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha, const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *ap) = 0;
virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha, const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *ap) = 0;

// Performs a rank-2 update of a symmetric packed matrix.
//
//     a <- alpha * x * y' + alpha * y * x' + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
// matrix, supplied in packed form (the `ap` argument).
virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &x, int incx,
                        const DeviceMemory<float> &y, int incy,
                        DeviceMemory<float> *ap) = 0;
virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &x, int incx,
                        const DeviceMemory<double> &y, int incy,
                        DeviceMemory<double> *ap) = 0;

// Computes a matrix-vector product for a symmetric matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
// n-element vectors.
virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &x, int incx, float beta,
                        DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &x, int incx, double beta,
                        DeviceMemory<double> *y, int incy) = 0;

// Performs a rank-1 update of a symmetric matrix.
//
//     a <- alpha * x * x' + a,
//
// alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
// matrix.
virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha, const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *a, int lda) = 0;
virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha, const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *a, int lda) = 0;

// Performs a rank-2 update of a symmetric matrix.
//
//     a <- alpha * x * y' + alpha * y * x' + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
// matrix.
virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &x, int incx,
                        const DeviceMemory<float> &y, int incy,
                        DeviceMemory<float> *a, int lda) = 0;
virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &x, int incx,
                        const DeviceMemory<double> &y, int incy,
                        DeviceMemory<double> *a, int lda) = 0;

// Computes a matrix-vector product using a triangular band matrix.
//
//     x <- a * x,
// or
//     x <- a' * x,
// or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
// with k+1 diagonals; x is an n-element vector.
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<float>> &a,
                        int lda, DeviceMemory<std::complex<float>> *x,
                        int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<double>> &a,
                        int lda, DeviceMemory<std::complex<double>> *x,
                        int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// band matrix as below:
//
//     a * x = b,
// or
//     a' * x = b,
// or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular band matrix, with k+1 diagonals.
//
// NOTE(review): there is no separate `b` parameter; presumably `x` carries b
// on entry and is overwritten with the solution, as in reference BLAS --
// confirm against the backend implementation.
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<float>> &a,
                        int lda, DeviceMemory<std::complex<float>> *x,
                        int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<double>> &a,
                        int lda, DeviceMemory<std::complex<double>> *x,
                        int incx) = 0;

// Computes a matrix-vector product using a triangular packed matrix.
//
//     x <- a * x,
// or
//     x <- a' * x,
// or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
// supplied in packed form (the `ap` argument); x is an n-element vector.
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                        int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &ap,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &ap,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &ap,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// packed matrix as below:
//
//     a * x = b,
// or
//     a' * x = b,
// or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular matrix, supplied in packed form (the `ap` argument).
//
// NOTE(review): there is no separate `b` parameter; presumably `x` carries b
// on entry and is overwritten with the solution, as in reference BLAS --
// confirm against the backend implementation.
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                        int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &ap,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &ap,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &ap,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Computes a matrix-vector product using a triangular matrix.
//
//     x <- a * x,
// or
//     x <- a' * x,
// or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is an
// n-element vector.
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// matrix as below:
//
//     a * x = b,
// or
//     a' * x = b,
// or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular matrix.
//
// NOTE(review): there is no separate `b` parameter; presumably `x` carries b
// on entry and is overwritten with the solution, as in reference BLAS --
// confirm against the backend implementation.
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Computes a matrix-matrix product with general matrices:
//
//     c <- alpha * op(a) * op(b) + beta * c,
//
// op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
// beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
// op(b) is a k-by-n matrix; c is an m-by-n matrix.
//
// Note: The half interface uses float precision internally; the version
// that uses half precision internally is not yet supported. There is no
// batched version of the half-precision interface.
// NOTE(review): a DoBlasGemmBatched overload taking Eigen::half is declared
// further down in this class, so the last sentence above looks stale --
// confirm and update.
//
// Alpha/beta type matches `dtype`, unless `dtype` is `Eigen::half`, in that
// case the expected alpha/beta type is `float`.
virtual port::Status DoBlasGemm(Stream *stream, blas::Transpose transa,
                                blas::Transpose transb, uint64 m, uint64 n,
                                uint64 k, DataType dtype, const void *alpha,
                                const DeviceMemoryBase &a, int lda,
                                const DeviceMemoryBase &b, int ldb,
                                const void *beta, DeviceMemoryBase *c,
                                int ldc) = 0;

// Typed GEMM overloads that additionally take an `output_profile_result`
// out-parameter.
// NOTE(review): how `output_profile_result` is populated (and whether null is
// accepted) is not visible in this header -- confirm against the backends.
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    ProfileResult *output_profile_result) = 0;

// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
virtual bool GetBlasGemmAlgorithms(
    std::vector<AlgorithmType> *out_algorithms) = 0;

// Like DoBlasGemm, but accepts an algorithm and a compute type.
//
// The compute type lets you say (e.g.) that the inputs and outputs are
// Eigen::halfs, but you want the internal computations to be done with
// float32 precision.
//
// If output_profile_result is not null, a failure here does not put the
// stream in a failure state.  Instead, success/failure is indicated by
// output_profile_result->is_valid().  This lets you use this function for
// choosing the best algorithm among many (some of which may fail) without
// creating a new Stream for each attempt.
virtual port::Status DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
    DataType type_a, int lda, const DeviceMemoryBase &b, DataType type_b,
    int ldb, const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc,
    ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;

// Variant of DoBlasGemmWithAlgorithm that additionally takes a per-operand
// stride (stride_a, stride_b, stride_c) and a batch_count.
// NOTE(review): presumably the strides are element offsets between
// consecutive matrices of the batch -- units are not stated here; confirm.
virtual port::Status DoBlasGemmStridedBatchedWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
    DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
    DataType type_b, int ldb, int64_t stride_b, const void *beta,
    DeviceMemoryBase *c, DataType type_c, int ldc, int64_t stride_c,
    int batch_count, ComputationType computation_type,
    AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;

// Computes a batch of matrix-matrix product with general matrices.
// This is a batched version of DoBlasGemm.
// The batched GEMM computes matrix product for each input/output in a, b,
// and c, which contain batch_count DeviceMemory objects.
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha,
    const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
    const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
    std::complex<float> beta,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
    std::complex<double> beta,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;

// Batched gemm with strides instead of pointer arrays.
//
// As with DoBlasGemm, the operands are type-erased (DeviceMemoryBase) and
// alpha/beta are passed as `const void *`, with `dtype` naming the element
// type.
virtual port::Status DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, DataType dtype, const void *alpha,
    const DeviceMemoryBase &a, int lda, int64_t stride_a,
    const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
    DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) = 0;

// Computes a matrix-matrix product where one input matrix is Hermitian:
//
//     c <- alpha * a * b + beta * c,
// or
//     c <- alpha * b * a + beta * c,
//
// alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
// matrices.
virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a Hermitian rank-k update.
//
//     c <- alpha * a * conj(a') + beta * c,
// or
//     c <- alpha * conj(a') * a + beta * c,
//
// alpha and beta are real-valued scalars; c is an n-by-n Hermitian matrix; a
// is an n-by-k matrix in the first case and a k-by-n matrix in the second
// case.
virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        float alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc) = 0;
virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        double alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc) = 0;

// Performs a Hermitian rank-2k update.
//
//     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
// or
//     c <- alpha * conj(a') * b + conj(alpha) * conj(b') * a + beta * c,
//
// alpha and beta are scalars (alpha is complex, beta is real-valued); c is an
// n-by-n Hermitian matrix; a and b are n-by-k matrices in the first case and
// k-by-n matrices in the second case.
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<float> alpha,
                         const DeviceMemory<std::complex<float>> &a, int lda,
                         const DeviceMemory<std::complex<float>> &b, int ldb,
                         float beta, DeviceMemory<std::complex<float>> *c,
                         int ldc) = 0;
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<double> alpha,
                         const DeviceMemory<std::complex<double>> &a, int lda,
                         const DeviceMemory<std::complex<double>> &b, int ldb,
                         double beta, DeviceMemory<std::complex<double>> *c,
                         int ldc) = 0;

// Computes a matrix-matrix product where one input matrix is symmetric.
//
//     c <- alpha * a * b + beta * c,
// or
//     c <- alpha * b * a + beta * c,
//
// alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
// matrices.
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &b, int ldb, float beta,
                        DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &b, int ldb, double beta,
                        DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a symmetric rank-k update.
//
//     c <- alpha * a * a' + beta * c,
// or
//     c <- alpha * a' * a + beta * c,
//
// alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
// matrix in the first case and a k-by-n matrix in the second case.
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        float beta, DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        double beta, DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a symmetric rank-2k update.
//
//     c <- alpha * a * b' + alpha * b * a' + beta * c,
// or
//     c <- alpha * b' * a + alpha * a' * b + beta * c,
//
// alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
// n-by-k matrices in the first case and k-by-n matrices in the second case.
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         float alpha, const DeviceMemory<float> &a, int lda,
                         const DeviceMemory<float> &b, int ldb, float beta,
                         DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         double alpha, const DeviceMemory<double> &a, int lda,
                         const DeviceMemory<double> &b, int ldb, double beta,
                         DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<float> alpha,
                         const DeviceMemory<std::complex<float>> &a, int lda,
                         const DeviceMemory<std::complex<float>> &b, int ldb,
                         std::complex<float> beta,
                         DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<double> alpha,
                         const DeviceMemory<std::complex<double>> &a, int lda,
                         const DeviceMemory<std::complex<double>> &b, int ldb,
                         std::complex<double> beta,
                         DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Computes a matrix-matrix product where one input matrix is triangular.
//
//     b <- alpha * op(a) * b,
// or
//     b <- alpha * b * op(a)
//
// alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
// or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
// op(a) = conj(a').
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *b, int ldb) = 0;
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *b, int ldb) = 0;
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *b, int ldb) = 0;
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *b, int ldb) = 0;

// Solves a triangular matrix equation.
//
//     op(a) * x = alpha * b,
// or
//     x * op(a) = alpha * b
//
// alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
// upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
// or op(a) = conj(a').
//
// NOTE(review): there is no separate `x` parameter; presumably the solution
// overwrites `b`, as in reference BLAS -- confirm against the backend
// implementation.
1382 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1383 blas::UpperLower uplo, blas::Transpose transa, 1384 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 1385 const DeviceMemory<float> &a, int lda, 1386 DeviceMemory<float> *b, int ldb) = 0; 1387 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1388 blas::UpperLower uplo, blas::Transpose transa, 1389 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 1390 const DeviceMemory<double> &a, int lda, 1391 DeviceMemory<double> *b, int ldb) = 0; 1392 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1393 blas::UpperLower uplo, blas::Transpose transa, 1394 blas::Diagonal diag, uint64 m, uint64 n, 1395 std::complex<float> alpha, 1396 const DeviceMemory<std::complex<float>> &a, int lda, 1397 DeviceMemory<std::complex<float>> *b, int ldb) = 0; 1398 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 1399 blas::UpperLower uplo, blas::Transpose transa, 1400 blas::Diagonal diag, uint64 m, uint64 n, 1401 std::complex<double> alpha, 1402 const DeviceMemory<std::complex<double>> &a, int lda, 1403 DeviceMemory<std::complex<double>> *b, int ldb) = 0; 1404 1405 // Creates a backend-specific plan object for a blaslt matmul operation, which 1406 // can then be passed to DoBlasLtMatmul(). When possible, plans should be 1407 // created once and reused for multiple calls to DoBlasLtMatmul(). 1408 virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> 1409 CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) = 0; 1410 1411 // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are 1412 // returned in the order of increasing estimated compute time according to an 1413 // internal heuristic. The first returned algorithm can be used as the default 1414 // algorithm if no autotuning is to be performed. 
1415 virtual port::StatusOr< 1416 std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> 1417 GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, 1418 size_t max_workspace_size, 1419 int max_algorithm_count) = 0; 1420 1421 // Executes a blaslt matmul operation on the stream. If output_profile_result 1422 // is not nullptr, the operation is profiled, error messages are 1423 // suppressed, and output_profile_result->algorithm() is set to 1424 // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when 1425 // creating the plan, the bias argument here must refer to a valid device 1426 // vector of length equal to the number of rows in matrix c. If epilogue was 1427 // set to any other value then the bias argument here must be null. The bias 1428 // vector is broadcast across the batch dimension. 1429 // Note that the data types of a and b (c and bias) must match the ab_type 1430 // (c_type) with which the plan was created, and the data types of alpha and 1431 // beta must match the data type of c. 
1432 virtual bool DoBlasLtMatmul( 1433 Stream *stream, const blas::IBlasLtMatmulPlan *plan, 1434 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, 1435 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, 1436 DeviceMemoryBase c, ScratchAllocator *scratch_allocator, 1437 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, 1438 blas::ProfileResult *output_profile_result) = 0; 1439 1440 template <typename ABType, typename CType> 1441 bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan, 1442 const HostOrDeviceScalar<CType> &alpha, 1443 const DeviceMemory<ABType> &a, 1444 const DeviceMemory<ABType> &b, 1445 const HostOrDeviceScalar<CType> &beta, 1446 DeviceMemory<CType> *c, 1447 ScratchAllocator *scratch_allocator, 1448 const blas::IBlasLtMatmulAlgorithm *algorithm, 1449 const DeviceMemory<CType> &bias = {}, 1450 blas::ProfileResult *output_profile_result = nullptr) { 1451 constexpr blas::DataType ab_type = blas::ToDataType<ABType>::value; 1452 if (ab_type != plan->ab_type()) { 1453 VLOG(2) << "DoBlasLtMatmul returning false because a and b type does " 1454 "not match plan: expected " 1455 << plan->ab_type() << ", got " << ab_type; 1456 return false; 1457 } 1458 constexpr blas::DataType c_type = blas::ToDataType<CType>::value; 1459 if (c_type != plan->c_type()) { 1460 VLOG(2) << "DoBlasLtMatmul returning false because c type does " 1461 "not match plan: expected " 1462 << plan->c_type() << ", got " << c_type; 1463 return false; 1464 } 1465 return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c, 1466 scratch_allocator, algorithm, bias, 1467 output_profile_result); 1468 } 1469 1470 virtual port::Status GetVersion(std::string *version) = 0; 1471 1472 protected: 1473 BlasSupport() {} 1474 1475 private: 1476 SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport); 1477 }; 1478 1479 // Macro used to quickly declare overrides for abstract virtuals in the 1480 // BlasSupport base class. 
1481 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ 1482 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1483 const DeviceMemory<float> &x, int incx, \ 1484 DeviceMemory<float> *result) override; \ 1485 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1486 const DeviceMemory<double> &x, int incx, \ 1487 DeviceMemory<double> *result) override; \ 1488 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1489 const DeviceMemory<std::complex<float>> &x, int incx, \ 1490 DeviceMemory<float> *result) override; \ 1491 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1492 const DeviceMemory<std::complex<double>> &x, int incx, \ 1493 DeviceMemory<double> *result) override; \ 1494 bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, \ 1495 const DeviceMemory<float> &x, int incx, \ 1496 DeviceMemory<float> *y, int incy) override; \ 1497 bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, \ 1498 const DeviceMemory<double> &x, int incx, \ 1499 DeviceMemory<double> *y, int incy) override; \ 1500 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1501 std::complex<float> alpha, \ 1502 const DeviceMemory<std::complex<float>> &x, int incx, \ 1503 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1504 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1505 std::complex<double> alpha, \ 1506 const DeviceMemory<std::complex<double>> &x, int incx, \ 1507 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1508 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1509 const DeviceMemory<float> &x, int incx, \ 1510 DeviceMemory<float> *y, int incy) override; \ 1511 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1512 const DeviceMemory<double> &x, int incx, \ 1513 DeviceMemory<double> *y, int incy) override; \ 1514 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1515 const DeviceMemory<std::complex<float>> &x, int incx, \ 1516 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1517 bool 
DoBlasCopy(Stream *stream, uint64 elem_count, \ 1518 const DeviceMemory<std::complex<double>> &x, int incx, \ 1519 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1520 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1521 const DeviceMemory<float> &x, int incx, \ 1522 const DeviceMemory<float> &y, int incy, \ 1523 DeviceMemory<float> *result) override; \ 1524 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1525 const DeviceMemory<double> &x, int incx, \ 1526 const DeviceMemory<double> &y, int incy, \ 1527 DeviceMemory<double> *result) override; \ 1528 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1529 const DeviceMemory<std::complex<float>> &x, int incx, \ 1530 const DeviceMemory<std::complex<float>> &y, int incy, \ 1531 DeviceMemory<std::complex<float>> *result) override; \ 1532 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1533 const DeviceMemory<std::complex<double>> &x, int incx, \ 1534 const DeviceMemory<std::complex<double>> &y, int incy, \ 1535 DeviceMemory<std::complex<double>> *result) override; \ 1536 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1537 const DeviceMemory<std::complex<float>> &x, int incx, \ 1538 const DeviceMemory<std::complex<float>> &y, int incy, \ 1539 DeviceMemory<std::complex<float>> *result) override; \ 1540 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1541 const DeviceMemory<std::complex<double>> &x, int incx, \ 1542 const DeviceMemory<std::complex<double>> &y, int incy, \ 1543 DeviceMemory<std::complex<double>> *result) override; \ 1544 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1545 const DeviceMemory<float> &x, int incx, \ 1546 DeviceMemory<float> *result) override; \ 1547 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1548 const DeviceMemory<double> &x, int incx, \ 1549 DeviceMemory<double> *result) override; \ 1550 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1551 const DeviceMemory<std::complex<float>> &x, int incx, \ 1552 DeviceMemory<float> *result) 
override; \ 1553 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1554 const DeviceMemory<std::complex<double>> &x, int incx, \ 1555 DeviceMemory<double> *result) override; \ 1556 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1557 int incx, DeviceMemory<float> *y, int incy, float c, float s) \ 1558 override; \ 1559 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1560 int incx, DeviceMemory<double> *y, int incy, double c, \ 1561 double s) override; \ 1562 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1563 DeviceMemory<std::complex<float>> *x, int incx, \ 1564 DeviceMemory<std::complex<float>> *y, int incy, float c, \ 1565 float s) override; \ 1566 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1567 DeviceMemory<std::complex<double>> *x, int incx, \ 1568 DeviceMemory<std::complex<double>> *y, int incy, double c, \ 1569 double s) override; \ 1570 bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, \ 1571 DeviceMemory<float> *b, DeviceMemory<float> *c, \ 1572 DeviceMemory<float> *s) override; \ 1573 bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, \ 1574 DeviceMemory<double> *b, DeviceMemory<double> *c, \ 1575 DeviceMemory<double> *s) override; \ 1576 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, \ 1577 DeviceMemory<std::complex<float>> *b, \ 1578 DeviceMemory<float> *c, \ 1579 DeviceMemory<std::complex<float>> *s) override; \ 1580 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, \ 1581 DeviceMemory<std::complex<double>> *b, \ 1582 DeviceMemory<double> *c, \ 1583 DeviceMemory<std::complex<double>> *s) override; \ 1584 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1585 int incx, DeviceMemory<float> *y, int incy, \ 1586 const DeviceMemory<float> ¶m) override; \ 1587 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1588 int incx, DeviceMemory<double> *y, int incy, \ 1589 const 
DeviceMemory<double> ¶m) override; \ 1590 bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, \ 1591 DeviceMemory<float> *d2, DeviceMemory<float> *x1, \ 1592 const DeviceMemory<float> &y1, DeviceMemory<float> *param) \ 1593 override; \ 1594 bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, \ 1595 DeviceMemory<double> *d2, DeviceMemory<double> *x1, \ 1596 const DeviceMemory<double> &y1, \ 1597 DeviceMemory<double> *param) override; \ 1598 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1599 DeviceMemory<float> *x, int incx) override; \ 1600 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1601 DeviceMemory<double> *x, int incx) override; \ 1602 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1603 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1604 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1605 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1606 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1607 std::complex<float> alpha, \ 1608 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1609 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1610 std::complex<double> alpha, \ 1611 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1612 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1613 int incx, DeviceMemory<float> *y, int incy) override; \ 1614 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1615 int incx, DeviceMemory<double> *y, int incy) override; \ 1616 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1617 DeviceMemory<std::complex<float>> *x, int incx, \ 1618 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1619 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1620 DeviceMemory<std::complex<double>> *x, int incx, \ 1621 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1622 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1623 
const DeviceMemory<float> &x, int incx, \ 1624 DeviceMemory<int> *result) override; \ 1625 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1626 const DeviceMemory<double> &x, int incx, \ 1627 DeviceMemory<int> *result) override; \ 1628 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1629 const DeviceMemory<std::complex<float>> &x, int incx, \ 1630 DeviceMemory<int> *result) override; \ 1631 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1632 const DeviceMemory<std::complex<double>> &x, int incx, \ 1633 DeviceMemory<int> *result) override; \ 1634 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1635 const DeviceMemory<float> &x, int incx, \ 1636 DeviceMemory<int> *result) override; \ 1637 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1638 const DeviceMemory<double> &x, int incx, \ 1639 DeviceMemory<int> *result) override; \ 1640 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1641 const DeviceMemory<std::complex<float>> &x, int incx, \ 1642 DeviceMemory<int> *result) override; \ 1643 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1644 const DeviceMemory<std::complex<double>> &x, int incx, \ 1645 DeviceMemory<int> *result) override; \ 1646 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1647 uint64 kl, uint64 ku, float alpha, \ 1648 const DeviceMemory<float> &a, int lda, \ 1649 const DeviceMemory<float> &x, int incx, float beta, \ 1650 DeviceMemory<float> *y, int incy) override; \ 1651 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1652 uint64 kl, uint64 ku, double alpha, \ 1653 const DeviceMemory<double> &a, int lda, \ 1654 const DeviceMemory<double> &x, int incx, double beta, \ 1655 DeviceMemory<double> *y, int incy) override; \ 1656 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1657 uint64 kl, uint64 ku, std::complex<float> alpha, \ 1658 const DeviceMemory<std::complex<float>> &a, int lda, \ 1659 const 
DeviceMemory<std::complex<float>> &x, int incx, \ 1660 std::complex<float> beta, \ 1661 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1662 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1663 uint64 kl, uint64 ku, std::complex<double> alpha, \ 1664 const DeviceMemory<std::complex<double>> &a, int lda, \ 1665 const DeviceMemory<std::complex<double>> &x, int incx, \ 1666 std::complex<double> beta, \ 1667 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1668 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1669 float alpha, const DeviceMemory<float> &a, int lda, \ 1670 const DeviceMemory<float> &x, int incx, float beta, \ 1671 DeviceMemory<float> *y, int incy) override; \ 1672 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1673 double alpha, const DeviceMemory<double> &a, int lda, \ 1674 const DeviceMemory<double> &x, int incx, double beta, \ 1675 DeviceMemory<double> *y, int incy) override; \ 1676 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1677 std::complex<float> alpha, \ 1678 const DeviceMemory<std::complex<float>> &a, int lda, \ 1679 const DeviceMemory<std::complex<float>> &x, int incx, \ 1680 std::complex<float> beta, \ 1681 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1682 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1683 std::complex<double> alpha, \ 1684 const DeviceMemory<std::complex<double>> &a, int lda, \ 1685 const DeviceMemory<std::complex<double>> &x, int incx, \ 1686 std::complex<double> beta, \ 1687 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1688 bool DoBlasGemvWithProfiling( \ 1689 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \ 1690 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \ 1691 int incx, float beta, DeviceMemory<float> *y, int incy, \ 1692 blas::ProfileResult *output_profile_result) 
override; \ 1693 bool DoBlasGemvWithProfiling( \ 1694 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \ 1695 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \ 1696 int incx, double beta, DeviceMemory<double> *y, int incy, \ 1697 blas::ProfileResult *output_profile_result) override; \ 1698 bool DoBlasGemvWithProfiling( \ 1699 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1700 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \ 1701 int lda, const DeviceMemory<std::complex<float>> &x, int incx, \ 1702 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \ 1703 int incy, blas::ProfileResult *output_profile_result) override; \ 1704 bool DoBlasGemvWithProfiling( \ 1705 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1706 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \ 1707 int lda, const DeviceMemory<std::complex<double>> &x, int incx, \ 1708 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \ 1709 int incy, blas::ProfileResult *output_profile_result) override; \ 1710 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \ 1711 const DeviceMemory<float> &x, int incx, \ 1712 const DeviceMemory<float> &y, int incy, \ 1713 DeviceMemory<float> *a, int lda) override; \ 1714 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, \ 1715 const DeviceMemory<double> &x, int incx, \ 1716 const DeviceMemory<double> &y, int incy, \ 1717 DeviceMemory<double> *a, int lda) override; \ 1718 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1719 std::complex<float> alpha, \ 1720 const DeviceMemory<std::complex<float>> &x, int incx, \ 1721 const DeviceMemory<std::complex<float>> &y, int incy, \ 1722 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1723 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1724 std::complex<double> alpha, \ 1725 const DeviceMemory<std::complex<double>> &x, int incx, \ 1726 
const DeviceMemory<std::complex<double>> &y, int incy, \ 1727 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1728 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1729 std::complex<float> alpha, \ 1730 const DeviceMemory<std::complex<float>> &x, int incx, \ 1731 const DeviceMemory<std::complex<float>> &y, int incy, \ 1732 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1733 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1734 std::complex<double> alpha, \ 1735 const DeviceMemory<std::complex<double>> &x, int incx, \ 1736 const DeviceMemory<std::complex<double>> &y, int incy, \ 1737 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1738 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1739 std::complex<float> alpha, \ 1740 const DeviceMemory<std::complex<float>> &a, int lda, \ 1741 const DeviceMemory<std::complex<float>> &x, int incx, \ 1742 std::complex<float> beta, \ 1743 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1744 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1745 std::complex<double> alpha, \ 1746 const DeviceMemory<std::complex<double>> &a, int lda, \ 1747 const DeviceMemory<std::complex<double>> &x, int incx, \ 1748 std::complex<double> beta, \ 1749 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1750 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1751 std::complex<float> alpha, \ 1752 const DeviceMemory<std::complex<float>> &a, int lda, \ 1753 const DeviceMemory<std::complex<float>> &x, int incx, \ 1754 std::complex<float> beta, \ 1755 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1756 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1757 std::complex<double> alpha, \ 1758 const DeviceMemory<std::complex<double>> &a, int lda, \ 1759 const DeviceMemory<std::complex<double>> &x, int incx, \ 1760 std::complex<double> beta, \ 1761 DeviceMemory<std::complex<double>> *y, int 
incy) override; \ 1762 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1763 const DeviceMemory<std::complex<float>> &x, int incx, \ 1764 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1765 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1766 double alpha, const DeviceMemory<std::complex<double>> &x, \ 1767 int incx, DeviceMemory<std::complex<double>> *a, int lda) \ 1768 override; \ 1769 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1770 std::complex<float> alpha, \ 1771 const DeviceMemory<std::complex<float>> &x, int incx, \ 1772 const DeviceMemory<std::complex<float>> &y, int incy, \ 1773 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1774 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1775 std::complex<double> alpha, \ 1776 const DeviceMemory<std::complex<double>> &x, int incx, \ 1777 const DeviceMemory<std::complex<double>> &y, int incy, \ 1778 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1779 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1780 std::complex<float> alpha, \ 1781 const DeviceMemory<std::complex<float>> &ap, \ 1782 const DeviceMemory<std::complex<float>> &x, int incx, \ 1783 std::complex<float> beta, \ 1784 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1785 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1786 std::complex<double> alpha, \ 1787 const DeviceMemory<std::complex<double>> &ap, \ 1788 const DeviceMemory<std::complex<double>> &x, int incx, \ 1789 std::complex<double> beta, \ 1790 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1791 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1792 const DeviceMemory<std::complex<float>> &x, int incx, \ 1793 DeviceMemory<std::complex<float>> *ap) override; \ 1794 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1795 double alpha, const 
DeviceMemory<std::complex<double>> &x, \ 1796 int incx, DeviceMemory<std::complex<double>> *ap) override; \ 1797 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1798 std::complex<float> alpha, \ 1799 const DeviceMemory<std::complex<float>> &x, int incx, \ 1800 const DeviceMemory<std::complex<float>> &y, int incy, \ 1801 DeviceMemory<std::complex<float>> *ap) override; \ 1802 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1803 std::complex<double> alpha, \ 1804 const DeviceMemory<std::complex<double>> &x, int incx, \ 1805 const DeviceMemory<std::complex<double>> &y, int incy, \ 1806 DeviceMemory<std::complex<double>> *ap) override; \ 1807 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1808 float alpha, const DeviceMemory<float> &a, int lda, \ 1809 const DeviceMemory<float> &x, int incx, float beta, \ 1810 DeviceMemory<float> *y, int incy) override; \ 1811 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1812 double alpha, const DeviceMemory<double> &a, int lda, \ 1813 const DeviceMemory<double> &x, int incx, double beta, \ 1814 DeviceMemory<double> *y, int incy) override; \ 1815 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1816 float alpha, const DeviceMemory<float> &ap, \ 1817 const DeviceMemory<float> &x, int incx, float beta, \ 1818 DeviceMemory<float> *y, int incy) override; \ 1819 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1820 double alpha, const DeviceMemory<double> &ap, \ 1821 const DeviceMemory<double> &x, int incx, double beta, \ 1822 DeviceMemory<double> *y, int incy) override; \ 1823 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1824 const DeviceMemory<float> &x, int incx, \ 1825 DeviceMemory<float> *ap) override; \ 1826 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1827 double alpha, const DeviceMemory<double> &x, int incx, \ 1828 DeviceMemory<double> *ap) 
override; \ 1829 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1830 float alpha, const DeviceMemory<float> &x, int incx, \ 1831 const DeviceMemory<float> &y, int incy, \ 1832 DeviceMemory<float> *ap) override; \ 1833 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1834 double alpha, const DeviceMemory<double> &x, int incx, \ 1835 const DeviceMemory<double> &y, int incy, \ 1836 DeviceMemory<double> *ap) override; \ 1837 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1838 float alpha, const DeviceMemory<float> &a, int lda, \ 1839 const DeviceMemory<float> &x, int incx, float beta, \ 1840 DeviceMemory<float> *y, int incy) override; \ 1841 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1842 double alpha, const DeviceMemory<double> &a, int lda, \ 1843 const DeviceMemory<double> &x, int incx, double beta, \ 1844 DeviceMemory<double> *y, int incy) override; \ 1845 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1846 const DeviceMemory<float> &x, int incx, \ 1847 DeviceMemory<float> *a, int lda) override; \ 1848 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1849 double alpha, const DeviceMemory<double> &x, int incx, \ 1850 DeviceMemory<double> *a, int lda) override; \ 1851 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1852 float alpha, const DeviceMemory<float> &x, int incx, \ 1853 const DeviceMemory<float> &y, int incy, \ 1854 DeviceMemory<float> *a, int lda) override; \ 1855 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1856 double alpha, const DeviceMemory<double> &x, int incx, \ 1857 const DeviceMemory<double> &y, int incy, \ 1858 DeviceMemory<double> *a, int lda) override; \ 1859 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1860 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1861 uint64 k, const DeviceMemory<float> &a, int lda, \ 1862 DeviceMemory<float> *x, int incx) override; 
\ 1863 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1864 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1865 uint64 k, const DeviceMemory<double> &a, int lda, \ 1866 DeviceMemory<double> *x, int incx) override; \ 1867 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1868 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1869 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1870 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1871 override; \ 1872 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1873 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1874 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1875 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1876 override; \ 1877 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1878 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1879 uint64 k, const DeviceMemory<float> &a, int lda, \ 1880 DeviceMemory<float> *x, int incx) override; \ 1881 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1882 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1883 uint64 k, const DeviceMemory<double> &a, int lda, \ 1884 DeviceMemory<double> *x, int incx) override; \ 1885 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1886 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1887 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1888 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1889 override; \ 1890 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1891 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1892 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1893 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1894 override; \ 1895 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1896 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1897 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1898 int incx) override; \ 1899 bool DoBlasTpmv(Stream 
*stream, blas::UpperLower uplo, \ 1900 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1901 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1902 int incx) override; \ 1903 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1904 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1905 const DeviceMemory<std::complex<float>> &ap, \ 1906 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1907 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1908 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1909 const DeviceMemory<std::complex<double>> &ap, \ 1910 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1911 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1912 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1913 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1914 int incx) override; \ 1915 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1916 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1917 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1918 int incx) override; \ 1919 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1920 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1921 const DeviceMemory<std::complex<float>> &ap, \ 1922 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1923 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1924 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1925 const DeviceMemory<std::complex<double>> &ap, \ 1926 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1927 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1928 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1929 const DeviceMemory<float> &a, int lda, \ 1930 DeviceMemory<float> *x, int incx) override; \ 1931 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1932 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1933 const DeviceMemory<double> &a, int lda, \ 1934 DeviceMemory<double> *x, int 
incx) override; \ 1935 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1936 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1937 const DeviceMemory<std::complex<float>> &a, int lda, \ 1938 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1939 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1940 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1941 const DeviceMemory<std::complex<double>> &a, int lda, \ 1942 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1943 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1944 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1945 const DeviceMemory<float> &a, int lda, \ 1946 DeviceMemory<float> *x, int incx) override; \ 1947 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1948 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1949 const DeviceMemory<double> &a, int lda, \ 1950 DeviceMemory<double> *x, int incx) override; \ 1951 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1952 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1953 const DeviceMemory<std::complex<float>> &a, int lda, \ 1954 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1955 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1956 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1957 const DeviceMemory<std::complex<double>> &a, int lda, \ 1958 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1959 port::Status DoBlasGemm( \ 1960 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1961 uint64 m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ 1962 const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \ 1963 const void *beta, DeviceMemoryBase *c, int ldc) override; \ 1964 bool DoBlasGemmWithProfiling( \ 1965 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1966 uint64 m, uint64 n, uint64 k, float alpha, \ 1967 const DeviceMemory<Eigen::half> &a, int lda, \ 1968 
const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 1969 DeviceMemory<Eigen::half> *c, int ldc, \ 1970 blas::ProfileResult *output_profile_result) override; \ 1971 bool DoBlasGemmWithProfiling( \ 1972 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1973 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 1974 int lda, const DeviceMemory<float> &b, int ldb, float beta, \ 1975 DeviceMemory<float> *c, int ldc, \ 1976 blas::ProfileResult *output_profile_result) override; \ 1977 bool DoBlasGemmWithProfiling( \ 1978 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1979 uint64 m, uint64 n, uint64 k, double alpha, \ 1980 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 1981 int ldb, double beta, DeviceMemory<double> *c, int ldc, \ 1982 blas::ProfileResult *output_profile_result) override; \ 1983 bool DoBlasGemmWithProfiling( \ 1984 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1985 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 1986 const DeviceMemory<std::complex<float>> &a, int lda, \ 1987 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1988 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 1989 blas::ProfileResult *output_profile_result) override; \ 1990 bool DoBlasGemmWithProfiling( \ 1991 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1992 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 1993 const DeviceMemory<std::complex<double>> &a, int lda, \ 1994 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1995 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 1996 int ldc, blas::ProfileResult *output_profile_result) override; \ 1997 bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \ 1998 override; \ 1999 port::Status DoBlasGemmWithAlgorithm( \ 2000 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2001 uint64 m, uint64 n, 
uint64 k, const void *alpha, \ 2002 const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ 2003 const DeviceMemoryBase &b, blas::DataType type_b, int ldb, \ 2004 const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, \ 2005 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 2006 blas::ProfileResult *output_profile_result) override; \ 2007 bool DoBlasGemmBatched( \ 2008 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2009 uint64 m, uint64 n, uint64 k, float alpha, \ 2010 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \ 2011 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \ 2012 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \ 2013 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2014 bool DoBlasGemmBatched( \ 2015 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2016 uint64 m, uint64 n, uint64 k, float alpha, \ 2017 const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \ 2018 const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \ 2019 const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \ 2020 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2021 bool DoBlasGemmBatched( \ 2022 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2023 uint64 m, uint64 n, uint64 k, double alpha, \ 2024 const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \ 2025 const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \ 2026 const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \ 2027 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2028 bool DoBlasGemmBatched( \ 2029 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2030 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2031 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \ 2032 const 
port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \ 2033 std::complex<float> beta, \ 2034 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \ 2035 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2036 bool DoBlasGemmBatched( \ 2037 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2038 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2039 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, \ 2040 int lda, \ 2041 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \ 2042 int ldb, std::complex<double> beta, \ 2043 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ 2044 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2045 port::Status DoBlasGemmStridedBatched( \ 2046 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2047 uint64 m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ 2048 const DeviceMemoryBase &a, int lda, int64 stride_a, \ 2049 const DeviceMemoryBase &b, int ldb, int64 stride_b, const void *beta, \ 2050 DeviceMemoryBase *c, int ldc, int64 stride_c, int batch_count); \ 2051 port::Status DoBlasGemmStridedBatchedWithAlgorithm( \ 2052 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2053 uint64 m, uint64 n, uint64 k, const void *alpha, \ 2054 const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ 2055 int64 stride_a, const DeviceMemoryBase &b, blas::DataType type_b, \ 2056 int ldb, int64 stride_b, const void *beta, DeviceMemoryBase *c, \ 2057 blas::DataType type_c, int ldc, int64 stride_c, int batch_count, \ 2058 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 2059 blas::ProfileResult *output_profile_result) override; \ 2060 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2061 uint64 m, uint64 n, std::complex<float> alpha, \ 2062 const DeviceMemory<std::complex<float>> &a, int lda, \ 2063 const 
DeviceMemory<std::complex<float>> &b, int ldb, \ 2064 std::complex<float> beta, \ 2065 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2066 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2067 uint64 m, uint64 n, std::complex<double> alpha, \ 2068 const DeviceMemory<std::complex<double>> &a, int lda, \ 2069 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2070 std::complex<double> beta, \ 2071 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2072 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2073 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2074 const DeviceMemory<std::complex<float>> &a, int lda, \ 2075 float beta, DeviceMemory<std::complex<float>> *c, int ldc) \ 2076 override; \ 2077 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2078 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2079 const DeviceMemory<std::complex<double>> &a, int lda, \ 2080 double beta, DeviceMemory<std::complex<double>> *c, int ldc) \ 2081 override; \ 2082 bool DoBlasHer2k( \ 2083 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2084 uint64 k, std::complex<float> alpha, \ 2085 const DeviceMemory<std::complex<float>> &a, int lda, \ 2086 const DeviceMemory<std::complex<float>> &b, int ldb, float beta, \ 2087 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2088 bool DoBlasHer2k( \ 2089 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2090 uint64 k, std::complex<double> alpha, \ 2091 const DeviceMemory<std::complex<double>> &a, int lda, \ 2092 const DeviceMemory<std::complex<double>> &b, int ldb, double beta, \ 2093 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2094 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2095 uint64 m, uint64 n, float alpha, \ 2096 const DeviceMemory<float> &a, int lda, \ 2097 const DeviceMemory<float> &b, int ldb, float beta, \ 2098 DeviceMemory<float> *c, int ldc) 
override; \ 2099 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2100 uint64 m, uint64 n, double alpha, \ 2101 const DeviceMemory<double> &a, int lda, \ 2102 const DeviceMemory<double> &b, int ldb, double beta, \ 2103 DeviceMemory<double> *c, int ldc) override; \ 2104 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2105 uint64 m, uint64 n, std::complex<float> alpha, \ 2106 const DeviceMemory<std::complex<float>> &a, int lda, \ 2107 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2108 std::complex<float> beta, \ 2109 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2110 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2111 uint64 m, uint64 n, std::complex<double> alpha, \ 2112 const DeviceMemory<std::complex<double>> &a, int lda, \ 2113 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2114 std::complex<double> beta, \ 2115 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2116 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2117 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2118 const DeviceMemory<float> &a, int lda, float beta, \ 2119 DeviceMemory<float> *c, int ldc) override; \ 2120 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2121 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2122 const DeviceMemory<double> &a, int lda, double beta, \ 2123 DeviceMemory<double> *c, int ldc) override; \ 2124 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2125 blas::Transpose trans, uint64 n, uint64 k, \ 2126 std::complex<float> alpha, \ 2127 const DeviceMemory<std::complex<float>> &a, int lda, \ 2128 std::complex<float> beta, \ 2129 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2130 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2131 blas::Transpose trans, uint64 n, uint64 k, \ 2132 std::complex<double> alpha, \ 2133 const DeviceMemory<std::complex<double>> &a, int lda, \ 2134 
std::complex<double> beta, \ 2135 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2136 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2137 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2138 const DeviceMemory<float> &a, int lda, \ 2139 const DeviceMemory<float> &b, int ldb, float beta, \ 2140 DeviceMemory<float> *c, int ldc) override; \ 2141 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2142 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2143 const DeviceMemory<double> &a, int lda, \ 2144 const DeviceMemory<double> &b, int ldb, double beta, \ 2145 DeviceMemory<double> *c, int ldc) override; \ 2146 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2147 blas::Transpose trans, uint64 n, uint64 k, \ 2148 std::complex<float> alpha, \ 2149 const DeviceMemory<std::complex<float>> &a, int lda, \ 2150 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2151 std::complex<float> beta, \ 2152 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2153 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2154 blas::Transpose trans, uint64 n, uint64 k, \ 2155 std::complex<double> alpha, \ 2156 const DeviceMemory<std::complex<double>> &a, int lda, \ 2157 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2158 std::complex<double> beta, \ 2159 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2160 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2161 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2162 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2163 int lda, DeviceMemory<float> *b, int ldb) override; \ 2164 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2165 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2166 uint64 n, double alpha, const DeviceMemory<double> &a, \ 2167 int lda, DeviceMemory<double> *b, int ldb) override; \ 2168 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2169 
blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2170 uint64 n, std::complex<float> alpha, \ 2171 const DeviceMemory<std::complex<float>> &a, int lda, \ 2172 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2173 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2174 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2175 uint64 n, std::complex<double> alpha, \ 2176 const DeviceMemory<std::complex<double>> &a, int lda, \ 2177 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2178 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2179 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2180 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2181 int lda, DeviceMemory<float> *b, int ldb) override; \ 2182 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2183 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2184 uint64 n, double alpha, const DeviceMemory<double> &a, \ 2185 int lda, DeviceMemory<double> *b, int ldb) override; \ 2186 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2187 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2188 uint64 n, std::complex<float> alpha, \ 2189 const DeviceMemory<std::complex<float>> &a, int lda, \ 2190 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2191 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2192 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2193 uint64 n, std::complex<double> alpha, \ 2194 const DeviceMemory<std::complex<double>> &a, int lda, \ 2195 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2196 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> \ 2197 CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) override; \ 2198 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> \ 2199 GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, \ 2200 size_t
max_workspace_size, \ 2201 int max_algorithm_count) override; \ 2202 bool DoBlasLtMatmul( \ 2203 Stream *stream, const blas::IBlasLtMatmulPlan *plan, \ 2204 const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, \ 2205 DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, \ 2206 DeviceMemoryBase c, ScratchAllocator *scratch_allocator, \ 2207 const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, \ 2208 blas::ProfileResult *output_profile_result) override; \ 2209 port::Status GetVersion(std::string *version) override; 2210 2211 } // namespace blas 2212 } // namespace stream_executor 2213 2214 #endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 2215