/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Exposes the family of BLAS routines as pre-canned, high-performance calls
// for use in conjunction with the StreamExecutor abstraction.
//
// Note that this interface is optionally supported by platforms; see
// StreamExecutor::SupportsBlas() for details.
//
// This abstraction makes it simple to entrain BLAS operations on GPU data into
// a Stream -- users typically will not use this API directly, but will use the
// Stream builder methods to entrain these operations "under the hood". For
// example:
//
//  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
//  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
//  // ... populate x and y ...
//  Stream stream{stream_exec};
//  stream
//    .Init()
//    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
//  SE_CHECK_OK(stream.BlockHostUntilDone());
//
// By using stream operations in this manner the user can easily intermix
// custom kernel launches (via Stream::ThenLaunch()) with these pre-canned BLAS
// routines.
#ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ #include #include #include "tensorflow/stream_executor/dnn.h" // For DataType, ToDataType #include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/port.h" namespace Eigen { struct half; } // namespace Eigen namespace stream_executor { class Stream; class ScratchAllocator; template class DeviceMemory; template class HostOrDeviceScalar; namespace blas { // Specifies whether the input matrix will be transposed or // transposed+conjugated before any BLAS operations. enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; // Returns a name for t. std::string TransposeString(Transpose t); // Specifies whether the upper or lower triangular part of a // symmetric/Hermitian matrix is used. enum class UpperLower { kUpper, kLower }; // Returns a name for ul. std::string UpperLowerString(UpperLower ul); // Specifies whether a matrix is unit triangular. enum class Diagonal { kUnit, kNonUnit }; // Returns a name for d. std::string DiagonalString(Diagonal d); // Specifies whether a Hermitian matrix appears on the left or right in // operation. enum class Side { kLeft, kRight }; // Returns a name for s. std::string SideString(Side s); // Type with which intermediate computations of a blas routine are performed. // // Some blas calls can perform computations with a type that's different than // the type of their inputs/outputs. This lets you e.g. multiply two matrices // of int8s using float32s to store the matmul's intermediate values. enum class ComputationType { kF16, // 16-bit floating-point kF32, // 32-bit floating-point kF64, // 64-bit floating-point kI32, // 32-bit integer kComplexF32, // Complex number comprised of two f32s. kComplexF64, // Complex number comprised of two f64s. // The below values are only supported for BlasLt routines (both real and // complex). 
They use float32 for accumulation but round the input mantissas // to a smaller number of bits. kTF32AsF32, // 32-bit floating-point with reduced (>=10-bit) mantissa kBF16AsF32, // 32-bit floating-point with reduced (7-bit) mantissa }; enum class Epilogue { kDefault = 1, // No special postprocessing kReLU = 2, // Apply ReLU func point-wise to the results kBias = 4, // Add broadcasted bias vector to the results kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform }; // Converts a ComputationType to a string. std::string ComputationTypeString(ComputationType ty); std::ostream &operator<<(std::ostream &os, ComputationType ty); using dnn::DataType; using dnn::ToDataType; // Describes the type of pointers for the scaling factors alpha and beta in // blaslt routines. enum class PointerMode { kHost, kDevice, }; // Converts a ComputationType to a string. std::string DataTypeString(DataType ty); std::ostream &operator<<(std::ostream &os, DataType ty); // Opaque identifier for an "algorithm" used by a blas routine. This functions // as a hint to the blas library. typedef int64 AlgorithmType; constexpr AlgorithmType kDefaultAlgorithm = -1; constexpr AlgorithmType kDefaultBlasGemm = -2; constexpr AlgorithmType kDefaultBlasGemv = -3; constexpr AlgorithmType kNoAlgorithm = -4; // blas uses -1 to represent the default algorithm. This happens to match up // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert // to ensure that this assumption does not break. // If another blas implementation uses a different value for the default // algorithm, then it needs to convert kDefaultGemmAlgo to that value // (e.g. via a function called ToWhateverGemmAlgo). constexpr AlgorithmType kDefaultGemmAlgo = -1; // Describes the result of a performance experiment, usually timing the speed of // a particular AlgorithmType. 
// // If the call we were benchmarking failed (a common occurrence; not all // algorithms are valid for all calls), is_valid() will be false. class ProfileResult { public: bool is_valid() const { return is_valid_; } void set_is_valid(bool val) { is_valid_ = val; } AlgorithmType algorithm() const { return algorithm_; } void set_algorithm(AlgorithmType val) { algorithm_ = val; } float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } private: bool is_valid_ = false; AlgorithmType algorithm_ = kDefaultAlgorithm; float elapsed_time_in_ms_ = std::numeric_limits::max(); }; class AlgorithmConfig { public: AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} AlgorithmType algorithm() const { return algorithm_; } void set_algorithm(AlgorithmType val) { algorithm_ = val; } bool operator==(const AlgorithmConfig &other) const { return this->algorithm_ == other.algorithm_; } bool operator!=(const AlgorithmConfig &other) const { return !(*this == other); } std::string ToString() const; private: AlgorithmType algorithm_; }; struct IBlasLtMatmulPlan { // Returns the data type of the A and B (input) matrices. virtual DataType ab_type() const = 0; // Returns the data type of the C (input/output) matrix. virtual DataType c_type() const = 0; virtual ~IBlasLtMatmulPlan() {} }; struct IBlasLtMatmulAlgorithm { virtual ~IBlasLtMatmulAlgorithm() {} // Returns the index of the algorithm within the list returned by // GetBlasLtMatmulAlgorithms. virtual AlgorithmType index() const = 0; // Returns the workspace size required by the algorithm in bytes. virtual size_t workspace_size() const = 0; }; // Parameters for the CreateBlasLtMatmulPlan method. 
struct BlasLtMatmulPlanParams { DataType ab_type; DataType c_type; ComputationType computation_type; PointerMode pointer_mode; Epilogue epilogue; Transpose transa; Transpose transb; uint64 m; uint64 n; uint64 k; int64 lda; int64 ldb; int64 ldc; int batch_count = 1; int64 stride_a = 0; int64 stride_b = 0; int64 stride_c = 0; }; // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). // // Thread-hostile: CUDA associates a CUDA-context with a particular thread in // the system. Any operation that a user attempts to perform by enqueueing BLAS // operations on a thread not-associated with the CUDA-context has unknown // behavior at the current time; see b/13176597 class BlasSupport { public: virtual ~BlasSupport() {} // Computes the sum of magnitudes of the vector elements. // result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)|+ ... + |Re x(n)| // + |Im x(n)|. // Note that Im x(i) = 0 for real types float/double. virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; // Performs a BLAS y <- ax+y operation. 
virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) = 0; virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) = 0; virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, std::complex alpha, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, std::complex alpha, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) = 0; // Copies vector to another vector: y <- x. virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) = 0; virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) = 0; virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) = 0; // Performs a BLAS dot product result <- x . y. virtual bool DoBlasDot(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *result) = 0; virtual bool DoBlasDot(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *result) = 0; // Performs a BLAS dot product result <- conj(x) . y for complex types. virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *result) = 0; virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *result) = 0; // Performs a BLAS dot product result <- x . y for complex types. Note that // x is unconjugated in this routine. 
virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *result) = 0; virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *result) = 0; // Computes the Euclidean norm of a vector: result <- ||x||. // See the following link for more information of Euclidean norm: // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; // Performs rotation of points in the plane: // x(i) = c*x(i) + s*y(i) // y(i) = c*y(i) - s*x(i). virtual bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy, float c, float s) = 0; virtual bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy, double c, double s) = 0; virtual bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy, float c, float s) = 0; virtual bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy, double c, double s) = 0; // Computes the parameters for a Givens rotation. // Given the Cartesian coordinates (a, b) of a point, these routines return // the parameters c, s, r, and z associated with the Givens rotation. 
The // parameters c and s define a unitary matrix such that: // // | c s |.| a | = | r | // | -s c | | b | | 0 | // // The parameter z is defined such that if |a| > |b|, z is s; otherwise if // c is not 0 z is 1/c; otherwise z is 1. virtual bool DoBlasRotg(Stream *stream, DeviceMemory *a, DeviceMemory *b, DeviceMemory *c, DeviceMemory *s) = 0; virtual bool DoBlasRotg(Stream *stream, DeviceMemory *a, DeviceMemory *b, DeviceMemory *c, DeviceMemory *s) = 0; virtual bool DoBlasRotg(Stream *stream, DeviceMemory> *a, DeviceMemory> *b, DeviceMemory *c, DeviceMemory> *s) = 0; virtual bool DoBlasRotg(Stream *stream, DeviceMemory> *a, DeviceMemory> *b, DeviceMemory *c, DeviceMemory> *s) = 0; // Performs modified Givens rotation of points in the plane. // Given two vectors x and y, each vector element of these vectors is replaced // as follows: // // | x(i) | = H | x(i) | // | y(i) | | y(i) | // // for i=1 to n, where H is a modified Givens transformation matrix whose // values are stored in the param[1] through param[4] array. // For more information please Google this routine. virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy, const DeviceMemory ¶m) = 0; virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy, const DeviceMemory ¶m) = 0; // Computes the parameters for a modified Givens rotation. // Given Cartesian coordinates (x1, y1) of an input vector, these routines // compute the components of a modified Givens transformation matrix H that // zeros the y-component of the resulting vector: // // | x1 | = H | x1 * sqrt(d1) | // | 0 | | y1 * sqrt(d1) | // // For more information please Google this routine. 
virtual bool DoBlasRotmg(Stream *stream, DeviceMemory *d1, DeviceMemory *d2, DeviceMemory *x1, const DeviceMemory &y1, DeviceMemory *param) = 0; virtual bool DoBlasRotmg(Stream *stream, DeviceMemory *d1, DeviceMemory *d2, DeviceMemory *x1, const DeviceMemory &y1, DeviceMemory *param) = 0; // Computes the product of a vector by a scalar: x <- a*x. virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, DeviceMemory *x, int incx) = 0; virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, DeviceMemory *x, int incx) = 0; virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasScal(Stream *stream, uint64 elem_count, std::complex alpha, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasScal(Stream *stream, uint64 elem_count, std::complex alpha, DeviceMemory> *x, int incx) = 0; // Swaps a vector with another vector. virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy) = 0; virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy) = 0; virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy) = 0; // Finds the index of the element with maximum absolute value. 
virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; // Finds the index of the element with minimum absolute value. virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) = 0; // Computes a matrix-vector product using a general band matrix: // // y <- alpha * a * x + beta * y, // or // y <- alpha * a' * x + beta * y, // or // y <- alpha * conj(a') * x + beta * y, // // alpha and beta are scalars; a is an m-by-n general band matrix, with kl // sub-diagonals and ku super-diagonals; x is a vector with // n(trans==kNoTranspose)/m(otherwise) elements; // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, uint64 kl, uint64 ku, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, uint64 kl, uint64 ku, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, uint64 kl, uint64 ku, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, uint64 kl, uint64 ku, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; // Computes a matrix-vector product using a general matrix. // // y <- alpha * a * x + beta * y, // or // y <- alpha * a' * x + beta * y, // or // y <- alpha * conj(a') * x + beta * y, // // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector // with n(trans==kNoTranspose)/m(otherwise) elements; // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasGemvWithProfiling( Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemvWithProfiling( Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemvWithProfiling( Stream *stream, blas::Transpose trans, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemvWithProfiling( Stream *stream, blas::Transpose trans, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy, ProfileResult *output_profile_result) = 0; // Performs a rank-1 update of a general matrix. 
// // a <- alpha * x * y' + a, // // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is // an m-by-n general matrix. virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) = 0; virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) = 0; // Performs a rank-1 update (conjugated) of a general matrix. // // a <- alpha * x * conj(y') + a, // // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is // an m-by-n general matrix. virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) = 0; virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) = 0; // Performs a rank-1 update (unconjugated) of a general matrix. // // a <- alpha * x * y' + a, // // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is // an m-by-n general matrix. virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) = 0; virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) = 0; // Computes a matrix-vector product using a Hermitian band matrix. // // y <- alpha * a * x + beta * y, // // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k // super-diagonals; x and y are n-element vectors. 
virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; // Computes a matrix-vector product using a Hermitian matrix. // // y <- alpha * a * x + beta * y, // // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are // n-element vectors. virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; // Performs a rank-1 update of a Hermitian matrix. // // a <- alpha * x * conj(x') + a, // // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian // matrix. virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory> &x, int incx, DeviceMemory> *a, int lda) = 0; virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory> &x, int incx, DeviceMemory> *a, int lda) = 0; // Performs a rank-2 update of a Hermitian matrix. // // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a, // // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian // matrix. 
virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) = 0; virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) = 0; // Computes a matrix-vector product using a Hermitian packed matrix. // // y <- alpha * a * x + beta * y, // // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in // packed form; x and y are n-element vectors. virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &ap, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &ap, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) = 0; // Performs a rank-1 update of a Hermitian packed matrix. // // a <- alpha * x * conj(x') + a, // // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian // matrix, supplied in packed form. virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory> &x, int incx, DeviceMemory> *ap) = 0; virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory> &x, int incx, DeviceMemory> *ap) = 0; // Performs a rank-2 update of a Hermitian packed matrix. // // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a, // // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian // matrix, supplied in packed form. 
virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *ap) = 0; virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex alpha, const DeviceMemory> &x, int incx, const DeviceMemory> &y, int incy, DeviceMemory> *ap) = 0; // Computes a matrix-vector product using a symmetric band matrix. // // y <- alpha * a * x + beta * y, // // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k // super-diagonals; x and y are n-element vectors. virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) = 0; // Computes a matrix-vector product using a symmetric packed matrix. // // y <- alpha * a * x + beta * y, // // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in // packed form; x and y are n-element vectors. virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &ap, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &ap, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) = 0; // Performs a rank-1 update of a symmetric packed matrix. // // a <- alpha * x * x' + a, // // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric // matrix, supplied in packed form. 
virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, DeviceMemory *ap) = 0; virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, DeviceMemory *ap) = 0; // Performs a rank-2 update of a symmetric packed matrix. // // a <- alpha * x * x' + alpha * y * x' + a, // // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric // matrix, supplied in packed form. virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *ap) = 0; virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *ap) = 0; // Computes a matrix-vector product for a symmetric matrix. // // y <- alpha * a * x + beta * y, // // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are // n-element vectors. virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) = 0; virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) = 0; // Performs a rank-1 update of a symmetric matrix. // // a <- alpha * x * x' + a, // // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric // matrix. virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, DeviceMemory *a, int lda) = 0; virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, DeviceMemory *a, int lda) = 0; // Performs a rank-2 update of symmetric matrix. 
// // a <- alpha * x * x' + alpha * y * x' + a, // // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric // matrix. virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) = 0; virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) = 0; // Computes a matrix-vector product using a triangular band matrix. // // x <- a * x, // or // x <- a' * x, // or // x <- conj(a') * x, // // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix, // with k+1 diagonals; x is a n-element vector. virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; // Solves a system of linear equations whose coefficients are in a triangular // band matrix as below: // // a * x = b, // or // a' * x = b, // or // conj(a') * x = b, // // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or // lower triangular band matrix, with k+1 diagonals. 
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, uint64 k, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; // Computes a matrix-vector product using a triangular packed matrix. // // x <- a * x, // or // x <- a' * x, // or // x <- conj(a') * x, // // a is an n-by-n unit, or non-unit, upper or lower triangular matrix, // supplied in packed form; x is a n-element vector. 
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &ap, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &ap, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &ap, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &ap, DeviceMemory> *x, int incx) = 0; // Solves a system of linear equations whose coefficients are in a triangular // packed matrix as below: // // a * x = b, // or // a' * x = b, // or // conj(a') * x = b, // // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or // lower triangular matrix, supplied in packed form. virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &ap, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &ap, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &ap, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &ap, DeviceMemory> *x, int incx) = 0; // Computes a matrix-vector product using a triangular matrix. // // x <- a * x, // or // x <- a' * x, // or // x <- conj(a') * x, // // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is a // n-element vector. 
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; // Solves a system of linear equations whose coefficients are in a triangular // matrix as below: // // a * x = b, // or // a' * x = b, // or // conj(a') * x = b, // // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or // lower triangular matrix. virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) = 0; virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, blas::Diagonal diag, uint64 n, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) = 0; // Computes a matrix-matrix product with general matrices: // // c <- alpha * op(a) * op(b) + beta * c, // // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix; // op(b) is a 
k-by-n matrix; c is an m-by-n matrix. // // Note: The half interface uses float precision internally; the version // that uses half precision internally is not yet supported. There is no // batched version of the half-precision interface. virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasGemmWithProfiling( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithProfiling( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int 
ldc, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithProfiling( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithProfiling( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithProfiling( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc, ProfileResult *output_profile_result) = 0; // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. virtual bool GetBlasGemmAlgorithms( std::vector *out_algorithms) = 0; // Like DoBlasGemm, but accepts an algorithm and an compute type. // // The compute type lets you say (e.g.) that the inputs and outputs are // Eigen::halfs, but you want the internal computations to be done with // float32 precision. // // Note the subtle difference in the version that accepts Eigen:::half -- // alpha and beta have type const Eigen::half&, not float. // // If output_profile_result is not null, a failure here does not put the // stream in a failure state. Instead, success/failure is indicated by // output_profile_result->is_valid(). This lets you use this function for // choosing the best algorithm among many (some of which may fail) without // creating a new Stream for each attempt. 
virtual bool DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, const HostOrDeviceScalar &beta, DeviceMemory *c, int ldc, ComputationType computation_type, AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, const HostOrDeviceScalar &beta, DeviceMemory *c, int ldc, ComputationType computation_type, AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, const HostOrDeviceScalar &beta, DeviceMemory *c, int ldc, ComputationType computation_type, AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, const HostOrDeviceScalar &beta, DeviceMemory *c, int ldc, ComputationType computation_type, AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar> &alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, const HostOrDeviceScalar> &beta, DeviceMemory> *c, int ldc, ComputationType computation_type, AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; virtual bool DoBlasGemmWithAlgorithm( Stream 
*stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar> &alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, const HostOrDeviceScalar> &beta, DeviceMemory> *c, int ldc, ComputationType computation_type, AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; // Computes a batch of matrix-matrix product with general matrices. // This is a batched version of DoBlasGemm. // The batched GEMM computes matrix product for each input/output in a, b, // and c, which contain batch_count DeviceMemory objects. virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const port::ArraySlice *> &a, int lda, const port::ArraySlice *> &b, int ldb, float beta, const port::ArraySlice *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const port::ArraySlice *> &a, int lda, const port::ArraySlice *> &b, int ldb, float beta, const port::ArraySlice *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, double alpha, const port::ArraySlice *> &a, int lda, const port::ArraySlice *> &b, int ldb, double beta, const port::ArraySlice *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const port::ArraySlice> *> &a, int lda, const port::ArraySlice> *> &b, int ldb, std::complex beta, const port::ArraySlice> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, 
blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const port::ArraySlice> *> &a, int lda, const port::ArraySlice> *> &b, int ldb, std::complex beta, const port::ArraySlice> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; // Batched gemm with strides instead of pointer arrays. virtual bool DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, float beta, DeviceMemory *c, int ldc, int64 stride_c, int batch_count) = 0; virtual bool DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, float beta, DeviceMemory *c, int ldc, int64 stride_c, int batch_count) = 0; virtual bool DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, double beta, DeviceMemory *c, int ldc, int64 stride_c, int batch_count) = 0; virtual bool DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, int64 stride_a, const DeviceMemory> &b, int ldb, int64 stride_b, std::complex beta, DeviceMemory> *c, int ldc, int64 stride_c, int batch_count) = 0; virtual bool DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, int64 stride_a, const DeviceMemory> &b, int ldb, int64 stride_b, std::complex beta, DeviceMemory> *c, int ldc, int64 stride_c, int batch_count) = 0; // Computes a matrix-matrix product where one 
input matrix is Hermitian: // // c <- alpha * a * b + beta * c, // or // c <- alpha * b * a + beta * c, // // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n // matrices. virtual bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; // Performs a Hermitian rank-k update. // // c <- alpha * a * conj(a') + beta * c, // or // c <- alpha * conj(a') * a + beta * c, // // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a is an n-by-k // matrix in the first case and a k-by-n matrix in the second case. virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, float alpha, const DeviceMemory> &a, int lda, float beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, double alpha, const DeviceMemory> &a, int lda, double beta, DeviceMemory> *c, int ldc) = 0; // Performs a Hermitian rank-2k update. // // c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c, // or // c <- alpha * conj(b') * a + conj(alpha) * conj(a') * b + beta * c, // // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a and b are // n-by-k matrices in the first case and k-by-n matrices in the second case. 
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, float beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, double beta, DeviceMemory> *c, int ldc) = 0; // Computes a matrix-matrix product where one input matrix is symmetric. // // c <- alpha * a * b + beta * c, // or // c <- alpha * b * a + beta * c, // // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n // matrices. virtual bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; // Performs a symmetric rank-k update. // // c <- alpha * a * a' + beta * c, // or // c <- alpha * a' * a + beta * c, // // alpha and beta are scalars; c is a n-by-n symmetric matrix; a is an n-by-k // matrix in the first case and a k-by-n matrix in the second case. 
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, float beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, double beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, std::complex beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, std::complex beta, DeviceMemory> *c, int ldc) = 0; // Performs a symmetric rank-2k update. // // c <- alpha * a * b' + alpha * b * a' + beta * c, // or // c <- alpha * b' * a + alpha * a' * b + beta * c, // // alpha and beta are scalars; c is a n-by-n symmetric matrix; a and b are // n-by-k matrices in the first case and k-by-n matrices in the second case. 
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc) = 0; virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, uint64 k, std::complex alpha, const DeviceMemory> &a, int lda, const DeviceMemory> &b, int ldb, std::complex beta, DeviceMemory> *c, int ldc) = 0; // Computes a matrix-matrix product where one input matrix is triangular. // // b <- alpha * op(a) * b, // or // b <- alpha * b * op(a) // // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or // op(a) = conj(a'). 
virtual bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, float alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) = 0; virtual bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) = 0; virtual bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; virtual bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; // Solves a triangular matrix equation. // // op(a) * x = alpha * b, // or // x * op(a) = alpha * b // // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit, // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', // or op(a) = conj(a'). 
virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, float alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) = 0; virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) = 0; virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, blas::Diagonal diag, uint64 m, uint64 n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; // Creates a backend-specific plan object for a blaslt matmul operation, which // can then be passed to DoBlasLtMatmul(). When possible, plans should be // created once and reused for multiple calls to DoBlasLtMatmul(). virtual port::StatusOr> CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) = 0; // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are // returned in the order of increasing estimated compute time according to an // internal heuristic. The first returned algorithm can be used as the default // algorithm if no autotuning is to be performed. virtual port::StatusOr< std::vector>> GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_size, int max_algorithm_count) = 0; // Executes a blaslt matmul operation on the stream. If output_profile_result // is not nullptr, the operation is profiled, error messages are // suppressed, and output_profile_result->algorithm() is set to // algorithm->index(). 
If epilogue was set to kBias or kBiasThenReLU when // creating the plan, the bias argument here must refer to a valid device // vector of length equal to the number of rows in matrix c. If epilogue was // set to any other value then the bias argument here must be null. The bias // vector is broadcast across the batch dimension. // Note that the data types of a and b (c and bias) must match the ab_type // (c_type) with which the plan was created, and the data types of alpha and // beta must match the data type of c. virtual bool DoBlasLtMatmul( Stream *stream, const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar &alpha, DeviceMemoryBase a, DeviceMemoryBase b, const HostOrDeviceScalar &beta, DeviceMemoryBase c, ScratchAllocator *scratch_allocator, const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, blas::ProfileResult *output_profile_result) = 0; template bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar &alpha, const DeviceMemory &a, const DeviceMemory &b, const HostOrDeviceScalar &beta, DeviceMemory *c, ScratchAllocator *scratch_allocator, const blas::IBlasLtMatmulAlgorithm *algorithm, const DeviceMemory &bias = {}, blas::ProfileResult *output_profile_result = nullptr) { constexpr blas::DataType ab_type = blas::ToDataType::value; if (ab_type != plan->ab_type()) { VLOG(2) << "DoBlasLtMatmul returning false because a and b type does " "not match plan: expected " << plan->ab_type() << ", got " << ab_type; return false; } constexpr blas::DataType c_type = blas::ToDataType::value; if (c_type != plan->c_type()) { VLOG(2) << "DoBlasLtMatmul returning false because c type does " "not match plan: expected " << plan->c_type() << ", got " << c_type; return false; } return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c, scratch_allocator, algorithm, bias, output_profile_result); } virtual port::Status GetVersion(std::string *version) = 0; protected: BlasSupport() {} private: 
SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport); }; // Macro used to quickly declare overrides for abstract virtuals in the // BlasSupport base class. #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ bool DoBlasAsum(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasAsum(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasAsum(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasAsum(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ DeviceMemory> *y, int incy) override; \ bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ DeviceMemory> *y, int incy) override; \ bool DoBlasCopy(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ bool DoBlasCopy(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ bool DoBlasCopy(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory> *y, int incy) override; \ bool DoBlasCopy(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory> *y, int incy) override; \ bool DoBlasDot(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *result) override; \ bool DoBlasDot(Stream *stream, uint64 elem_count, 
\ const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *result) override; \ bool DoBlasDotc(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *result) override; \ bool DoBlasDotc(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *result) override; \ bool DoBlasDotu(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *result) override; \ bool DoBlasDotu(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *result) override; \ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *x, \ int incx, DeviceMemory *y, int incy, float c, float s) \ override; \ bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *x, \ int incx, DeviceMemory *y, int incy, double c, \ double s) override; \ bool DoBlasRot(Stream *stream, uint64 elem_count, \ DeviceMemory> *x, int incx, \ DeviceMemory> *y, int incy, float c, \ float s) override; \ bool DoBlasRot(Stream *stream, uint64 elem_count, \ DeviceMemory> *x, int incx, \ DeviceMemory> *y, int incy, double c, \ double s) override; \ bool DoBlasRotg(Stream *stream, DeviceMemory *a, \ DeviceMemory *b, DeviceMemory *c, \ DeviceMemory *s) override; \ bool DoBlasRotg(Stream *stream, DeviceMemory *a, \ DeviceMemory *b, DeviceMemory 
*c, \ DeviceMemory *s) override; \ bool DoBlasRotg(Stream *stream, DeviceMemory> *a, \ DeviceMemory> *b, \ DeviceMemory *c, \ DeviceMemory> *s) override; \ bool DoBlasRotg(Stream *stream, DeviceMemory> *a, \ DeviceMemory> *b, \ DeviceMemory *c, \ DeviceMemory> *s) override; \ bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *x, \ int incx, DeviceMemory *y, int incy, \ const DeviceMemory ¶m) override; \ bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *x, \ int incx, DeviceMemory *y, int incy, \ const DeviceMemory ¶m) override; \ bool DoBlasRotmg(Stream *stream, DeviceMemory *d1, \ DeviceMemory *d2, DeviceMemory *x1, \ const DeviceMemory &y1, DeviceMemory *param) \ override; \ bool DoBlasRotmg(Stream *stream, DeviceMemory *d1, \ DeviceMemory *d2, DeviceMemory *x1, \ const DeviceMemory &y1, \ DeviceMemory *param) override; \ bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ DeviceMemory *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ DeviceMemory *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ DeviceMemory> *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ DeviceMemory> *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64 elem_count, \ std::complex alpha, \ DeviceMemory> *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64 elem_count, \ std::complex alpha, \ DeviceMemory> *x, int incx) override; \ bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory *x, \ int incx, DeviceMemory *y, int incy) override; \ bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory *x, \ int incx, DeviceMemory *y, int incy) override; \ bool DoBlasSwap(Stream *stream, uint64 elem_count, \ DeviceMemory> *x, int incx, \ DeviceMemory> *y, int incy) override; \ bool DoBlasSwap(Stream *stream, uint64 elem_count, \ DeviceMemory> *x, int incx, \ DeviceMemory> *y, int incy) 
override; \ bool DoBlasIamax(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamax(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamax(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamax(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamin(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamin(Stream *stream, uint64 elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamin(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasIamin(Stream *stream, uint64 elem_count, \ const DeviceMemory> &x, int incx, \ DeviceMemory *result) override; \ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ uint64 kl, uint64 ku, float alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, float beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ uint64 kl, uint64 ku, double alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, double beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ uint64 kl, uint64 ku, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ uint64 kl, uint64 ku, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasGemv(Stream 
*stream, blas::Transpose trans, uint64 m, uint64 n, \ float alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, float beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ double alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, double beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasGemvWithProfiling( \ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \ const DeviceMemory &a, int lda, const DeviceMemory &x, \ int incx, float beta, DeviceMemory *y, int incy, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemvWithProfiling( \ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \ const DeviceMemory &a, int lda, const DeviceMemory &x, \ int incx, double beta, DeviceMemory *y, int incy, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemvWithProfiling( \ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ std::complex alpha, const DeviceMemory> &a, \ int lda, const DeviceMemory> &x, int incx, \ std::complex beta, DeviceMemory> *y, \ int incy, blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemvWithProfiling( \ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ std::complex alpha, const DeviceMemory> &a, \ int lda, const DeviceMemory> &x, int incx, \ std::complex beta, DeviceMemory> *y, \ int incy, blas::ProfileResult *output_profile_result) override; \ bool 
DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \ const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *a, int lda) override; \ bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, \ const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *a, int lda) override; \ bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *a, int lda) override; \ bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *a, int lda) override; \ bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *a, int lda) override; \ bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *a, int lda) override; \ bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ 
bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ const DeviceMemory> &x, int incx, \ DeviceMemory> *a, int lda) override; \ bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory> &x, \ int incx, DeviceMemory> *a, int lda) \ override; \ bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *a, int lda) override; \ bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *a, int lda) override; \ bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &ap, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &ap, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ const DeviceMemory> &x, int incx, \ DeviceMemory> *ap) override; \ bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory> &x, \ int incx, DeviceMemory> *ap) override; \ bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *ap) override; \ bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ std::complex alpha, \ const DeviceMemory> &x, int incx, \ const DeviceMemory> &y, int incy, \ DeviceMemory> *ap) override; \ bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ float alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, float beta, \ DeviceMemory 
*y, int incy) override; \ bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ double alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, double beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ float alpha, const DeviceMemory &ap, \ const DeviceMemory &x, int incx, float beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory &ap, \ const DeviceMemory &x, int incx, double beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ const DeviceMemory &x, int incx, \ DeviceMemory *ap) override; \ bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory &x, int incx, \ DeviceMemory *ap) override; \ bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ float alpha, const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *ap) override; \ bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *ap) override; \ bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ float alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, float beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &x, int incx, double beta, \ DeviceMemory *y, int incy) override; \ bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ const DeviceMemory &x, int incx, \ DeviceMemory *a, int lda) override; \ bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory &x, int incx, \ DeviceMemory *a, int lda) override; \ bool 
DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ float alpha, const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *a, int lda) override; \ bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ double alpha, const DeviceMemory &x, int incx, \ const DeviceMemory &y, int incy, \ DeviceMemory *a, int lda) override; \ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory> &a, \ int lda, DeviceMemory> *x, int incx) \ override; \ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory> &a, \ int lda, DeviceMemory> *x, int incx) \ override; \ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory> &a, \ int lda, DeviceMemory> *x, int incx) \ override; \ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ uint64 k, const DeviceMemory> &a, \ int lda, DeviceMemory> *x, int incx) \ override; \ bool DoBlasTpmv(Stream *stream, 
blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &ap, DeviceMemory *x, \ int incx) override; \ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &ap, DeviceMemory *x, \ int incx) override; \ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &ap, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &ap, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &ap, DeviceMemory *x, \ int incx) override; \ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &ap, DeviceMemory *x, \ int incx) override; \ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &ap, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &ap, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTrmv(Stream *stream, 
blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory &a, int lda, \ DeviceMemory *x, int incx) override; \ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *x, int incx) override; \ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, blas::Diagonal diag, uint64 n, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *x, int incx) override; \ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ float alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, float beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ float alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, float beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ double alpha, const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, double beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ blas::Transpose transb, uint64 m, 
uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasGemmWithProfiling( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, float beta, \ DeviceMemory *c, int ldc, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithProfiling( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, \ int lda, const DeviceMemory &b, int ldb, float beta, \ DeviceMemory *c, int ldc, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithProfiling( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, double alpha, \ const DeviceMemory &a, int lda, const DeviceMemory &b, \ int ldb, double beta, DeviceMemory *c, int ldc, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithProfiling( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, DeviceMemory> *c, int ldc, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithProfiling( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, DeviceMemory> *c, \ int ldc, blas::ProfileResult *output_profile_result) override; \ bool GetBlasGemmAlgorithms(std::vector *out_algorithms) \ override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, \ const DeviceMemory &a, 
int lda, const DeviceMemory &b, \ int ldb, const HostOrDeviceScalar &beta, DeviceMemory *c, \ int ldc, blas::ComputationType computation_type, \ blas::AlgorithmType algorithm, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, \ const HostOrDeviceScalar &alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, \ const HostOrDeviceScalar &beta, \ DeviceMemory *c, int ldc, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, \ const DeviceMemory &a, int lda, const DeviceMemory &b, \ int ldb, const HostOrDeviceScalar &beta, DeviceMemory *c, \ int ldc, blas::ComputationType computation_type, \ blas::AlgorithmType algorithm, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar &alpha, \ const DeviceMemory &a, int lda, const DeviceMemory &b, \ int ldb, const HostOrDeviceScalar &beta, \ DeviceMemory *c, int ldc, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, \ const HostOrDeviceScalar> &alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ const HostOrDeviceScalar> &beta, \ DeviceMemory> *c, int ldc, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, 
blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, \ const HostOrDeviceScalar> &alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ const HostOrDeviceScalar> &beta, \ DeviceMemory> *c, int ldc, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, \ const port::ArraySlice *> &a, int lda, \ const port::ArraySlice *> &b, int ldb, \ float beta, const port::ArraySlice *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, \ const port::ArraySlice *> &a, int lda, \ const port::ArraySlice *> &b, int ldb, float beta, \ const port::ArraySlice *> &c, int ldc, \ int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, double alpha, \ const port::ArraySlice *> &a, int lda, \ const port::ArraySlice *> &b, int ldb, double beta, \ const port::ArraySlice *> &c, int ldc, \ int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex alpha, \ const port::ArraySlice> *> &a, int lda, \ const port::ArraySlice> *> &b, int ldb, \ std::complex beta, \ const port::ArraySlice> *> &c, int ldc, \ int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex alpha, \ const port::ArraySlice> *> &a, \ int lda, \ const port::ArraySlice> *> &b, \ int ldb, std::complex beta, \ const 
port::ArraySlice> *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, \ const DeviceMemory &a, int lda, int64 stride_a, \ const DeviceMemory &b, int ldb, int64 stride_b, float beta, \ DeviceMemory *c, int ldc, int64 stride_c, int batch_count); \ bool DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, \ int lda, int64 stride_a, const DeviceMemory &b, int ldb, \ int64 stride_b, float beta, DeviceMemory *c, int ldc, \ int64 stride_c, int batch_count); \ bool DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, double alpha, \ const DeviceMemory &a, int lda, int64 stride_a, \ const DeviceMemory &b, int ldb, int64 stride_b, double beta, \ DeviceMemory *c, int ldc, int64 stride_c, int batch_count); \ bool DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex alpha, \ const DeviceMemory> &a, int lda, int64 stride_a, \ const DeviceMemory> &b, int ldb, int64 stride_b, \ std::complex beta, DeviceMemory> *c, int ldc, \ int64 stride_c, int batch_count); \ bool DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex alpha, \ const DeviceMemory> &a, int lda, int64 stride_a, \ const DeviceMemory> &b, int ldb, int64 stride_b, \ std::complex beta, DeviceMemory> *c, \ int ldc, int64 stride_c, int batch_count); \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool 
DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, float alpha, \ const DeviceMemory> &a, int lda, \ float beta, DeviceMemory> *c, int ldc) \ override; \ bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, double alpha, \ const DeviceMemory> &a, int lda, \ double beta, DeviceMemory> *c, int ldc) \ override; \ bool DoBlasHer2k( \ Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ uint64 k, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, float beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasHer2k( \ Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ uint64 k, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, double beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, float alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, float beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, double alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, double beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ 
const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, float alpha, \ const DeviceMemory &a, int lda, float beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, double alpha, \ const DeviceMemory &a, int lda, double beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, float alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, float beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, double alpha, \ const DeviceMemory &a, int lda, \ const DeviceMemory &b, int ldb, double beta, \ DeviceMemory *c, int ldc) override; \ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ blas::Transpose trans, uint64 n, uint64 k, \ std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &b, int ldb, \ std::complex beta, \ DeviceMemory> *c, int ldc) override; \ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 
blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, float alpha, const DeviceMemory &a, \ int lda, DeviceMemory *b, int ldb) override; \ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, double alpha, const DeviceMemory &a, \ int lda, DeviceMemory *b, int ldb) override; \ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *b, int ldb) override; \ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *b, int ldb) override; \ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, float alpha, const DeviceMemory &a, \ int lda, DeviceMemory *b, int ldb) override; \ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, double alpha, const DeviceMemory &a, \ int lda, DeviceMemory *b, int ldb) override; \ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *b, int ldb) override; \ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64 m, \ uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *b, int ldb) override; \ port::StatusOr> \ CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams ¶ms) override; \ port::StatusOr>> \ GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, \ size_t max_workspace_size, \ int 
max_algorithm_count) override; \ bool DoBlasLtMatmul( \ Stream *stream, const blas::IBlasLtMatmulPlan *plan, \ const HostOrDeviceScalar &alpha, DeviceMemoryBase a, \ DeviceMemoryBase b, const HostOrDeviceScalar &beta, \ DeviceMemoryBase c, ScratchAllocator *scratch_allocator, \ const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, \ blas::ProfileResult *output_profile_result) override; \ port::Status GetVersion(std::string *version) override; } // namespace blas } // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_