/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Exposes the family of BLAS routines as pre-canned high performance calls for
// use in conjunction with the StreamExecutor abstraction.
//
// Note that this interface is optionally supported by platforms; see
// StreamExecutor::SupportsBlas() for details.
//
// This abstraction makes it simple to entrain BLAS operations on GPU data into
// a Stream -- users typically will not use this API directly, but will use the
// Stream builder methods to entrain these operations "under the hood". For
// example:
//
//  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
//  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
//  // ... populate x and y ...
//  Stream stream{stream_exec};
//  stream
//    .Init()
//    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
//  SE_CHECK_OK(stream.BlockHostUntilDone());
//
// By using stream operations in this manner the user can easily intermix custom
// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS
// routines.
39 40 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_ 41 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_ 42 43 #include <complex> 44 #include <vector> 45 46 #include "tensorflow/compiler/xla/stream_executor/data_type.h" 47 #include "tensorflow/compiler/xla/stream_executor/device_memory.h" 48 #include "tensorflow/compiler/xla/stream_executor/dnn.pb.h" 49 #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h" 50 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" 51 #include "tensorflow/compiler/xla/stream_executor/platform/port.h" 52 53 namespace Eigen { 54 struct half; 55 } // namespace Eigen 56 57 namespace stream_executor { 58 59 class Stream; 60 class ScratchAllocator; 61 62 template <typename ElemT> 63 class DeviceMemory; 64 65 template <typename ElemT> 66 class HostOrDeviceScalar; 67 68 template <typename T> 69 using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>; // non-absl ok 70 71 namespace blas { 72 73 // Specifies whether the input matrix will be transposed or 74 // transposed+conjugated before any BLAS operations. 75 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; 76 77 // Returns a name for t. 78 std::string TransposeString(Transpose t); 79 80 // Specifies whether the upper or lower triangular part of a 81 // symmetric/Hermitian matrix is used. 82 enum class UpperLower { kUpper, kLower }; 83 84 // Returns a name for ul. 85 std::string UpperLowerString(UpperLower ul); 86 87 // Specifies whether a matrix is unit triangular. 88 enum class Diagonal { kUnit, kNonUnit }; 89 90 // Returns a name for d. 91 std::string DiagonalString(Diagonal d); 92 93 // Specifies whether a Hermitian matrix appears on the left or right in 94 // operation. 95 enum class Side { kLeft, kRight }; 96 97 // Returns a name for s. 98 std::string SideString(Side s); 99 100 // Type with which intermediate computations of a blas routine are performed. 
101 // 102 // Some blas calls can perform computations with a type that's different than 103 // the type of their inputs/outputs. This lets you e.g. multiply two matrices 104 // of int8s using float32s to store the matmul's intermediate values. 105 enum class ComputationType { 106 kF16, // 16-bit floating-point 107 kF32, // 32-bit floating-point 108 kF64, // 64-bit floating-point 109 kI32, // 32-bit integer 110 // The below values use float32 for accumulation, but allow the inputs and 111 // outputs to be downcast to a lower precision: 112 kF16AsF32, // Allow downcast to F16 precision. 113 kBF16AsF32, // Allow downcast to BF16 precision. 114 kTF32AsF32, // Allow downcast to TF32 precision. 115 }; 116 117 // Converts a ComputationType to a string. 118 std::string ComputationTypeString(ComputationType ty); 119 120 std::ostream &operator<<(std::ostream &os, ComputationType ty); 121 122 using dnn::DataType; 123 using dnn::ToDataType; 124 125 // Converts a ComputationType to a string. 126 std::string DataTypeString(DataType ty); 127 128 std::ostream &operator<<(std::ostream &os, DataType ty); 129 130 // Opaque identifier for an "algorithm" used by a blas routine. This functions 131 // as a hint to the blas library. 132 typedef int64_t AlgorithmType; 133 constexpr AlgorithmType kDefaultAlgorithm = -1; 134 constexpr AlgorithmType kDefaultBlasGemm = -2; 135 constexpr AlgorithmType kDefaultBlasGemv = -3; 136 constexpr AlgorithmType kNoAlgorithm = -4; 137 138 // blas uses -1 to represent the default algorithm. This happens to match up 139 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast 140 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert 141 // to ensure that this assumption does not break. 142 // If another blas implementation uses a different value for the default 143 // algorithm, then it needs to convert kDefaultGemmAlgo to that value 144 // (e.g. via a function called ToWhateverGemmAlgo). 
145 constexpr AlgorithmType kDefaultGemmAlgo = -1; 146 147 // Describes the result of a performance experiment, usually timing the speed of 148 // a particular AlgorithmType. 149 // 150 // If the call we were benchmarking failed (a common occurrence; not all 151 // algorithms are valid for all calls), is_valid() will be false. 152 class ProfileResult { 153 public: is_valid()154 bool is_valid() const { return is_valid_; } set_is_valid(bool val)155 void set_is_valid(bool val) { is_valid_ = val; } algorithm()156 AlgorithmType algorithm() const { return algorithm_; } set_algorithm(AlgorithmType val)157 void set_algorithm(AlgorithmType val) { algorithm_ = val; } elapsed_time_in_ms()158 float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } set_elapsed_time_in_ms(float val)159 void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } 160 161 private: 162 bool is_valid_ = false; 163 AlgorithmType algorithm_ = kDefaultAlgorithm; 164 float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); 165 }; 166 167 class AlgorithmConfig { 168 public: AlgorithmConfig()169 AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} AlgorithmConfig(AlgorithmType algorithm)170 explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} algorithm()171 AlgorithmType algorithm() const { return algorithm_; } set_algorithm(AlgorithmType val)172 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 173 bool operator==(const AlgorithmConfig &other) const { 174 return this->algorithm_ == other.algorithm_; 175 } 176 bool operator!=(const AlgorithmConfig &other) const { 177 return !(*this == other); 178 } 179 std::string ToString() const; 180 181 private: 182 AlgorithmType algorithm_; 183 }; 184 185 // Opaque identifier specifying the precision to use in gemm calls. 
186 typedef int64_t ComputePrecision; 187 constexpr ComputePrecision kDefaultComputePrecision = 0; 188 189 // This struct contains the metadata of a matrix, e.g., its base address and 190 // dimensions. 191 struct MatrixDescriptor { 192 DeviceMemoryBase data; 193 int64_t leading_dim_stride; 194 int64_t batch_stride; 195 Transpose transpose; 196 197 template <typename T> castMatrixDescriptor198 DeviceMemory<T> cast() const { 199 return DeviceMemory<T>(data); 200 } 201 }; 202 203 // BLAS support interface -- this can be derived from a GPU executor when the 204 // underlying platform has an BLAS library implementation available. See 205 // StreamExecutor::AsBlas(). 206 // 207 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in 208 // the system. Any operation that a user attempts to perform by enqueueing BLAS 209 // operations on a thread not-associated with the CUDA-context has unknown 210 // behavior at the current time; see b/13176597 211 class BlasSupport { 212 public: ~BlasSupport()213 virtual ~BlasSupport() {} 214 215 // Performs a BLAS y <- ax+y operation. 216 virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, 217 const DeviceMemory<float> &x, int incx, 218 DeviceMemory<float> *y, int incy) = 0; 219 virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha, 220 const DeviceMemory<double> &x, int incx, 221 DeviceMemory<double> *y, int incy) = 0; 222 virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, 223 std::complex<float> alpha, 224 const DeviceMemory<std::complex<float>> &x, int incx, 225 DeviceMemory<std::complex<float>> *y, int incy) = 0; 226 virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, 227 std::complex<double> alpha, 228 const DeviceMemory<std::complex<double>> &x, int incx, 229 DeviceMemory<std::complex<double>> *y, int incy) = 0; 230 231 // Copies vector to another vector: y <- x. 
232 virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, 233 const DeviceMemory<float> &x, int incx, 234 DeviceMemory<float> *y, int incy) = 0; 235 virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, 236 const DeviceMemory<double> &x, int incx, 237 DeviceMemory<double> *y, int incy) = 0; 238 virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, 239 const DeviceMemory<std::complex<float>> &x, int incx, 240 DeviceMemory<std::complex<float>> *y, int incy) = 0; 241 virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count, 242 const DeviceMemory<std::complex<double>> &x, int incx, 243 DeviceMemory<std::complex<double>> *y, int incy) = 0; 244 245 // Computes the product of a vector by a scalar: x <- a*x. 246 virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, 247 DeviceMemory<float> *x, int incx) = 0; 248 virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, 249 DeviceMemory<double> *x, int incx) = 0; 250 virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, 251 DeviceMemory<std::complex<float>> *x, int incx) = 0; 252 virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, 253 DeviceMemory<std::complex<double>> *x, int incx) = 0; 254 virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, 255 std::complex<float> alpha, 256 DeviceMemory<std::complex<float>> *x, int incx) = 0; 257 virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, 258 std::complex<double> alpha, 259 DeviceMemory<std::complex<double>> *x, int incx) = 0; 260 261 // Computes a matrix-vector product using a general matrix. 262 // 263 // y <- alpha * a * x + beta * y, 264 // or 265 // y <- alpha * a' * x + beta * y, 266 // or 267 // y <- alpha * conj(a') * x + beta * y, 268 // 269 // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector 270 // with n(trans==kNoTranspose)/m(otherwise) elements; 271 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
272 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, 273 uint64_t n, float alpha, const DeviceMemory<float> &a, 274 int lda, const DeviceMemory<float> &x, int incx, 275 float beta, DeviceMemory<float> *y, int incy) = 0; 276 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, 277 uint64_t n, double alpha, 278 const DeviceMemory<double> &a, int lda, 279 const DeviceMemory<double> &x, int incx, double beta, 280 DeviceMemory<double> *y, int incy) = 0; 281 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, 282 uint64_t n, std::complex<float> alpha, 283 const DeviceMemory<std::complex<float>> &a, int lda, 284 const DeviceMemory<std::complex<float>> &x, int incx, 285 std::complex<float> beta, 286 DeviceMemory<std::complex<float>> *y, int incy) = 0; 287 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, 288 uint64_t n, std::complex<double> alpha, 289 const DeviceMemory<std::complex<double>> &a, int lda, 290 const DeviceMemory<std::complex<double>> &x, int incx, 291 std::complex<double> beta, 292 DeviceMemory<std::complex<double>> *y, int incy) = 0; 293 294 virtual bool DoBlasGemvWithProfiling( 295 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, float alpha, 296 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, 297 int incx, float beta, DeviceMemory<float> *y, int incy, 298 ProfileResult *output_profile_result) = 0; 299 virtual bool DoBlasGemvWithProfiling( 300 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, double alpha, 301 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, 302 int incx, double beta, DeviceMemory<double> *y, int incy, 303 ProfileResult *output_profile_result) = 0; 304 virtual bool DoBlasGemvWithProfiling( 305 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, 306 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, 307 int lda, const DeviceMemory<std::complex<float>> 
&x, int incx, 308 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, 309 ProfileResult *output_profile_result) = 0; 310 virtual bool DoBlasGemvWithProfiling( 311 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, 312 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, 313 int lda, const DeviceMemory<std::complex<double>> &x, int incx, 314 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, 315 int incy, ProfileResult *output_profile_result) = 0; 316 317 // Computes a matrix-vector product using a symmetric band matrix. 318 // 319 // y <- alpha * a * x + beta * y, 320 // 321 // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k 322 // super-diagonals; x and y are n-element vectors. 323 virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, 324 uint64_t k, float alpha, const DeviceMemory<float> &a, 325 int lda, const DeviceMemory<float> &x, int incx, 326 float beta, DeviceMemory<float> *y, int incy) = 0; 327 virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, 328 uint64_t k, double alpha, 329 const DeviceMemory<double> &a, int lda, 330 const DeviceMemory<double> &x, int incx, double beta, 331 DeviceMemory<double> *y, int incy) = 0; 332 333 // Computes a matrix-matrix product with general matrices: 334 // 335 // c <- alpha * op(a) * op(b) + beta * c, 336 // 337 // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and 338 // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix; 339 // op(b) is a k-by-n matrix; c is an m-by-n matrix. 340 // 341 // Note: The half interface uses float precision internally; the version 342 // that uses half precision internally is not yet supported. There is no 343 // batched version of the half-precision interface. 344 // 345 // Alpha/beta type matches `dtype`, unless `dtype` is `Eigen::half`, in that 346 // case the expected alpha/beta type is `float`. 
347 virtual port::Status DoBlasGemm(Stream *stream, blas::Transpose transa, 348 blas::Transpose transb, uint64_t m, uint64 n, 349 uint64_t k, DataType dtype, const void *alpha, 350 const DeviceMemoryBase &a, int lda, 351 const DeviceMemoryBase &b, int ldb, 352 const void *beta, DeviceMemoryBase *c, 353 int ldc, ComputePrecision precision) = 0; 354 355 virtual bool DoBlasGemmWithProfiling( 356 Stream *stream, blas::Transpose transa, blas::Transpose transb, 357 uint64_t m, uint64_t n, uint64 k, float alpha, 358 const DeviceMemory<Eigen::half> &a, int lda, 359 const DeviceMemory<Eigen::half> &b, int ldb, float beta, 360 DeviceMemory<Eigen::half> *c, int ldc, 361 ProfileResult *output_profile_result) = 0; 362 virtual bool DoBlasGemmWithProfiling( 363 Stream *stream, blas::Transpose transa, blas::Transpose transb, 364 uint64_t m, uint64_t n, uint64 k, float alpha, 365 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, 366 int ldb, float beta, DeviceMemory<float> *c, int ldc, 367 ProfileResult *output_profile_result) = 0; 368 virtual bool DoBlasGemmWithProfiling( 369 Stream *stream, blas::Transpose transa, blas::Transpose transb, 370 uint64_t m, uint64_t n, uint64 k, double alpha, 371 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, 372 int ldb, double beta, DeviceMemory<double> *c, int ldc, 373 ProfileResult *output_profile_result) = 0; 374 virtual bool DoBlasGemmWithProfiling( 375 Stream *stream, blas::Transpose transa, blas::Transpose transb, 376 uint64_t m, uint64_t n, uint64 k, std::complex<float> alpha, 377 const DeviceMemory<std::complex<float>> &a, int lda, 378 const DeviceMemory<std::complex<float>> &b, int ldb, 379 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, 380 ProfileResult *output_profile_result) = 0; 381 virtual bool DoBlasGemmWithProfiling( 382 Stream *stream, blas::Transpose transa, blas::Transpose transb, 383 uint64_t m, uint64_t n, uint64 k, std::complex<double> alpha, 384 const 
DeviceMemory<std::complex<double>> &a, int lda, 385 const DeviceMemory<std::complex<double>> &b, int ldb, 386 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 387 ProfileResult *output_profile_result) = 0; 388 389 // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. 390 virtual bool GetBlasGemmAlgorithms( 391 Stream *stream, std::vector<AlgorithmType> *out_algorithms) = 0; 392 393 // Like DoBlasGemm, but accepts an algorithm and an compute type. 394 // 395 // The compute type lets you say (e.g.) that the inputs and outputs are 396 // Eigen::halfs, but you want the internal computations to be done with 397 // float32 precision. 398 // 399 // If output_profile_result is not null, a failure here does not put the 400 // stream in a failure state. Instead, success/failure is indicated by 401 // output_profile_result->is_valid(). This lets you use this function for 402 // choosing the best algorithm among many (some of which may fail) without 403 // creating a new Stream for each attempt. 
404 virtual port::Status DoBlasGemmWithAlgorithm( 405 Stream *stream, blas::Transpose transa, blas::Transpose transb, 406 uint64_t m, uint64_t n, uint64 k, const void *alpha, 407 const DeviceMemoryBase &a, DataType type_a, int lda, 408 const DeviceMemoryBase &b, DataType type_b, int ldb, const void *beta, 409 DeviceMemoryBase *c, DataType type_c, int ldc, 410 ComputationType computation_type, AlgorithmType algorithm, 411 ProfileResult *output_profile_result) = 0; 412 413 virtual port::Status DoBlasGemmStridedBatchedWithAlgorithm( 414 Stream *stream, blas::Transpose transa, blas::Transpose transb, 415 uint64_t m, uint64_t n, uint64 k, const void *alpha, 416 const DeviceMemoryBase &a, DataType type_a, int lda, int64_t stride_a, 417 const DeviceMemoryBase &b, DataType type_b, int ldb, int64_t stride_b, 418 const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc, 419 int64_t stride_c, int batch_count, ComputationType computation_type, 420 AlgorithmType algorithm, ProfileResult *output_profile_result) = 0; 421 422 // Computes a batch of matrix-matrix product with general matrices. 423 // This is a batched version of DoBlasGemm. 424 // The batched GEMM computes matrix product for each input/output in a, b, 425 // and c, which contain batch_count DeviceMemory objects. 
426 virtual bool DoBlasGemmBatched( 427 Stream *stream, blas::Transpose transa, blas::Transpose transb, 428 uint64_t m, uint64_t n, uint64 k, float alpha, 429 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, // non-absl ok 430 int lda, 431 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, // non-absl ok 432 int ldb, float beta, 433 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, // non-absl ok 434 int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; 435 virtual bool DoBlasGemmBatched( 436 Stream *stream, blas::Transpose transa, blas::Transpose transb, 437 uint64_t m, uint64_t n, uint64 k, float alpha, 438 const port::ArraySlice<DeviceMemory<float> *> &a, int lda, // non-absl ok 439 const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, // non-absl ok 440 float beta, 441 const port::ArraySlice<DeviceMemory<float> *> &c, // non-absl ok 442 int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; 443 virtual bool DoBlasGemmBatched( 444 Stream *stream, blas::Transpose transa, blas::Transpose transb, 445 uint64_t m, uint64_t n, uint64 k, double alpha, 446 const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok 447 int lda, 448 const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok 449 int ldb, double beta, 450 const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok 451 int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; 452 virtual bool DoBlasGemmBatched( 453 Stream *stream, blas::Transpose transa, blas::Transpose transb, 454 uint64_t m, uint64_t n, uint64 k, std::complex<float> alpha, 455 const DeviceMemorySlice<std::complex<float>> &a, int lda, 456 const DeviceMemorySlice<std::complex<float>> &b, int ldb, 457 std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c, 458 int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; 459 virtual bool DoBlasGemmBatched( 460 Stream *stream, blas::Transpose transa, blas::Transpose transb, 461 uint64_t m, 
uint64_t n, uint64 k, std::complex<double> alpha, 462 const DeviceMemorySlice<std::complex<double>> &a, int lda, 463 const DeviceMemorySlice<std::complex<double>> &b, int ldb, 464 std::complex<double> beta, 465 const DeviceMemorySlice<std::complex<double>> &c, int ldc, 466 int batch_count, ScratchAllocator *scratch_allocator) = 0; 467 468 // Batched gemm with strides instead of pointer arrays. 469 virtual port::Status DoBlasGemmStridedBatched( 470 Stream *stream, blas::Transpose transa, blas::Transpose transb, 471 uint64_t m, uint64_t n, uint64 k, DataType dtype, const void *alpha, 472 const DeviceMemoryBase &a, int lda, int64_t stride_a, 473 const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, 474 DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) = 0; 475 476 // Solves a triangular matrix equation. 477 // 478 // op(a) * x = alpha * b, 479 // or 480 // x * op(a) = alpha * b 481 // 482 // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit, 483 // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', 484 // or op(a) = conj(a'). 
485 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 486 blas::UpperLower uplo, blas::Transpose transa, 487 blas::Diagonal diag, uint64_t m, uint64 n, 488 float alpha, const DeviceMemory<float> &a, int lda, 489 DeviceMemory<float> *b, int ldb) = 0; 490 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 491 blas::UpperLower uplo, blas::Transpose transa, 492 blas::Diagonal diag, uint64_t m, uint64 n, 493 double alpha, const DeviceMemory<double> &a, int lda, 494 DeviceMemory<double> *b, int ldb) = 0; 495 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 496 blas::UpperLower uplo, blas::Transpose transa, 497 blas::Diagonal diag, uint64_t m, uint64 n, 498 std::complex<float> alpha, 499 const DeviceMemory<std::complex<float>> &a, int lda, 500 DeviceMemory<std::complex<float>> *b, int ldb) = 0; 501 virtual bool DoBlasTrsm(Stream *stream, blas::Side side, 502 blas::UpperLower uplo, blas::Transpose transa, 503 blas::Diagonal diag, uint64_t m, uint64 n, 504 std::complex<double> alpha, 505 const DeviceMemory<std::complex<double>> &a, int lda, 506 DeviceMemory<std::complex<double>> *b, int ldb) = 0; 507 508 // Same as DoBlasTrsm, but operates over a list of a's and b's. The lists 509 // `as` and `bs` must have the same length. 
510 virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, 511 blas::UpperLower uplo, blas::Transpose transa, 512 blas::Diagonal diag, uint64_t m, uint64 n, 513 float alpha, const DeviceMemory<float *> &as, 514 int lda, DeviceMemory<float *> *bs, int ldb, 515 int batch_count) = 0; 516 virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, 517 blas::UpperLower uplo, blas::Transpose transa, 518 blas::Diagonal diag, uint64_t m, uint64 n, 519 double alpha, const DeviceMemory<double *> &as, 520 int lda, DeviceMemory<double *> *bs, int ldb, 521 int batch_count) = 0; 522 virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, 523 blas::UpperLower uplo, blas::Transpose transa, 524 blas::Diagonal diag, uint64_t m, uint64 n, 525 std::complex<float> alpha, 526 const DeviceMemory<std::complex<float> *> &as, 527 int lda, 528 DeviceMemory<std::complex<float> *> *bs, 529 int ldb, int batch_count) = 0; 530 virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, 531 blas::UpperLower uplo, blas::Transpose transa, 532 blas::Diagonal diag, uint64_t m, uint64 n, 533 std::complex<double> alpha, 534 const DeviceMemory<std::complex<double> *> &as, 535 int lda, 536 DeviceMemory<std::complex<double> *> *bs, 537 int ldb, int batch_count) = 0; 538 539 virtual port::Status GetVersion(std::string *version) = 0; 540 541 protected: BlasSupport()542 BlasSupport() {} 543 544 private: 545 SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport); 546 }; 547 548 // Macro used to quickly declare overrides for abstract virtuals in the 549 // BlasSupport base class. 
550 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ 551 bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, \ 552 const DeviceMemory<float> &x, int incx, \ 553 DeviceMemory<float> *y, int incy) override; \ 554 bool DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha, \ 555 const DeviceMemory<double> &x, int incx, \ 556 DeviceMemory<double> *y, int incy) override; \ 557 bool DoBlasAxpy(Stream *stream, uint64_t elem_count, \ 558 std::complex<float> alpha, \ 559 const DeviceMemory<std::complex<float>> &x, int incx, \ 560 DeviceMemory<std::complex<float>> *y, int incy) override; \ 561 bool DoBlasAxpy(Stream *stream, uint64_t elem_count, \ 562 std::complex<double> alpha, \ 563 const DeviceMemory<std::complex<double>> &x, int incx, \ 564 DeviceMemory<std::complex<double>> *y, int incy) override; \ 565 bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ 566 const DeviceMemory<float> &x, int incx, \ 567 DeviceMemory<float> *y, int incy) override; \ 568 bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ 569 const DeviceMemory<double> &x, int incx, \ 570 DeviceMemory<double> *y, int incy) override; \ 571 bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ 572 const DeviceMemory<std::complex<float>> &x, int incx, \ 573 DeviceMemory<std::complex<float>> *y, int incy) override; \ 574 bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ 575 const DeviceMemory<std::complex<double>> &x, int incx, \ 576 DeviceMemory<std::complex<double>> *y, int incy) override; \ 577 bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \ 578 DeviceMemory<float> *x, int incx) override; \ 579 bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \ 580 DeviceMemory<double> *x, int incx) override; \ 581 bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \ 582 DeviceMemory<std::complex<float>> *x, int incx) override; \ 583 bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \ 584 
DeviceMemory<std::complex<double>> *x, int incx) override; \ 585 bool DoBlasScal(Stream *stream, uint64_t elem_count, \ 586 std::complex<float> alpha, \ 587 DeviceMemory<std::complex<float>> *x, int incx) override; \ 588 bool DoBlasScal(Stream *stream, uint64_t elem_count, \ 589 std::complex<double> alpha, \ 590 DeviceMemory<std::complex<double>> *x, int incx) override; \ 591 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 592 float alpha, const DeviceMemory<float> &a, int lda, \ 593 const DeviceMemory<float> &x, int incx, float beta, \ 594 DeviceMemory<float> *y, int incy) override; \ 595 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 596 double alpha, const DeviceMemory<double> &a, int lda, \ 597 const DeviceMemory<double> &x, int incx, double beta, \ 598 DeviceMemory<double> *y, int incy) override; \ 599 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 600 std::complex<float> alpha, \ 601 const DeviceMemory<std::complex<float>> &a, int lda, \ 602 const DeviceMemory<std::complex<float>> &x, int incx, \ 603 std::complex<float> beta, \ 604 DeviceMemory<std::complex<float>> *y, int incy) override; \ 605 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 606 std::complex<double> alpha, \ 607 const DeviceMemory<std::complex<double>> &a, int lda, \ 608 const DeviceMemory<std::complex<double>> &x, int incx, \ 609 std::complex<double> beta, \ 610 DeviceMemory<std::complex<double>> *y, int incy) override; \ 611 bool DoBlasGemvWithProfiling( \ 612 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 613 float alpha, const DeviceMemory<float> &a, int lda, \ 614 const DeviceMemory<float> &x, int incx, float beta, \ 615 DeviceMemory<float> *y, int incy, \ 616 blas::ProfileResult *output_profile_result) override; \ 617 bool DoBlasGemvWithProfiling( \ 618 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 619 double alpha, const 
DeviceMemory<double> &a, int lda, \ 620 const DeviceMemory<double> &x, int incx, double beta, \ 621 DeviceMemory<double> *y, int incy, \ 622 blas::ProfileResult *output_profile_result) override; \ 623 bool DoBlasGemvWithProfiling( \ 624 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 625 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \ 626 int lda, const DeviceMemory<std::complex<float>> &x, int incx, \ 627 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \ 628 int incy, blas::ProfileResult *output_profile_result) override; \ 629 bool DoBlasGemvWithProfiling( \ 630 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ 631 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \ 632 int lda, const DeviceMemory<std::complex<double>> &x, int incx, \ 633 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \ 634 int incy, blas::ProfileResult *output_profile_result) override; \ 635 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, uint64 k, \ 636 float alpha, const DeviceMemory<float> &a, int lda, \ 637 const DeviceMemory<float> &x, int incx, float beta, \ 638 DeviceMemory<float> *y, int incy) override; \ 639 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, uint64 k, \ 640 double alpha, const DeviceMemory<double> &a, int lda, \ 641 const DeviceMemory<double> &x, int incx, double beta, \ 642 DeviceMemory<double> *y, int incy) override; \ 643 port::Status DoBlasGemm( \ 644 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 645 uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ 646 const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \ 647 const void *beta, DeviceMemoryBase *c, int ldc, \ 648 blas::ComputePrecision precision) override; \ 649 bool DoBlasGemmWithProfiling( \ 650 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 651 uint64_t m, uint64 n, uint64 k, float alpha, 
\ 652 const DeviceMemory<Eigen::half> &a, int lda, \ 653 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 654 DeviceMemory<Eigen::half> *c, int ldc, \ 655 blas::ProfileResult *output_profile_result) override; \ 656 bool DoBlasGemmWithProfiling( \ 657 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 658 uint64_t m, uint64 n, uint64 k, float alpha, \ 659 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, \ 660 int ldb, float beta, DeviceMemory<float> *c, int ldc, \ 661 blas::ProfileResult *output_profile_result) override; \ 662 bool DoBlasGemmWithProfiling( \ 663 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 664 uint64_t m, uint64 n, uint64 k, double alpha, \ 665 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 666 int ldb, double beta, DeviceMemory<double> *c, int ldc, \ 667 blas::ProfileResult *output_profile_result) override; \ 668 bool DoBlasGemmWithProfiling( \ 669 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 670 uint64_t m, uint64 n, uint64 k, std::complex<float> alpha, \ 671 const DeviceMemory<std::complex<float>> &a, int lda, \ 672 const DeviceMemory<std::complex<float>> &b, int ldb, \ 673 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 674 blas::ProfileResult *output_profile_result) override; \ 675 bool DoBlasGemmWithProfiling( \ 676 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 677 uint64_t m, uint64 n, uint64 k, std::complex<double> alpha, \ 678 const DeviceMemory<std::complex<double>> &a, int lda, \ 679 const DeviceMemory<std::complex<double>> &b, int ldb, \ 680 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 681 int ldc, blas::ProfileResult *output_profile_result) override; \ 682 bool GetBlasGemmAlgorithms(Stream *stream, \ 683 std::vector<blas::AlgorithmType> *out_algorithms) \ 684 override; \ 685 port::Status DoBlasGemmWithAlgorithm( \ 686 Stream *stream, blas::Transpose transa, 
blas::Transpose transb, \ 687 uint64_t m, uint64 n, uint64 k, const void *alpha, \ 688 const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ 689 const DeviceMemoryBase &b, blas::DataType type_b, int ldb, \ 690 const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, \ 691 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 692 blas::ProfileResult *output_profile_result) override; \ 693 bool DoBlasGemmBatched( \ 694 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 695 uint64_t m, uint64 n, uint64 k, float alpha, \ 696 const DeviceMemorySlice<Eigen::half> &a, int lda, \ 697 const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta, \ 698 const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count, \ 699 ScratchAllocator *scratch_allocator) override; \ 700 bool DoBlasGemmBatched( \ 701 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 702 uint64_t m, uint64 n, uint64 k, float alpha, \ 703 const DeviceMemorySlice<float> &a, int lda, \ 704 const DeviceMemorySlice<float> &b, int ldb, float beta, \ 705 const DeviceMemorySlice<float> &c, int ldc, int batch_count, \ 706 ScratchAllocator *scratch_allocator) override; \ 707 bool DoBlasGemmBatched( \ 708 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 709 uint64_t m, uint64 n, uint64 k, double alpha, \ 710 const DeviceMemorySlice<double> &a, int lda, \ 711 const DeviceMemorySlice<double> &b, int ldb, double beta, \ 712 const DeviceMemorySlice<double> &c, int ldc, int batch_count, \ 713 ScratchAllocator *scratch_allocator) override; \ 714 bool DoBlasGemmBatched( \ 715 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 716 uint64_t m, uint64 n, uint64 k, std::complex<float> alpha, \ 717 const DeviceMemorySlice<std::complex<float>> &a, int lda, \ 718 const DeviceMemorySlice<std::complex<float>> &b, int ldb, \ 719 std::complex<float> beta, \ 720 const DeviceMemorySlice<std::complex<float>> &c, int ldc, \ 721 int 
batch_count, ScratchAllocator *scratch_allocator) override; \ 722 bool DoBlasGemmBatched( \ 723 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 724 uint64_t m, uint64 n, uint64 k, std::complex<double> alpha, \ 725 const DeviceMemorySlice<std::complex<double>> &a, int lda, \ 726 const DeviceMemorySlice<std::complex<double>> &b, int ldb, \ 727 std::complex<double> beta, \ 728 const DeviceMemorySlice<std::complex<double>> &c, int ldc, \ 729 int batch_count, ScratchAllocator *scratch_allocator) override; \ 730 port::Status DoBlasGemmStridedBatched( \ 731 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 732 uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ 733 const DeviceMemoryBase &a, int lda, int64_t stride_a, \ 734 const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, \ 735 DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count); \ 736 port::Status DoBlasGemmStridedBatchedWithAlgorithm( \ 737 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 738 uint64_t m, uint64 n, uint64 k, const void *alpha, \ 739 const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ 740 int64_t stride_a, const DeviceMemoryBase &b, blas::DataType type_b, \ 741 int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, \ 742 blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, \ 743 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 744 blas::ProfileResult *output_profile_result) override; \ 745 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 746 blas::Transpose transa, blas::Diagonal diag, uint64_t m, \ 747 uint64_t n, float alpha, const DeviceMemory<float> &a, \ 748 int lda, DeviceMemory<float> *b, int ldb) override; \ 749 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 750 blas::Transpose transa, blas::Diagonal diag, uint64_t m, \ 751 uint64_t n, double alpha, const DeviceMemory<double> &a, \ 
752 int lda, DeviceMemory<double> *b, int ldb) override; \ 753 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 754 blas::Transpose transa, blas::Diagonal diag, uint64_t m, \ 755 uint64_t n, std::complex<float> alpha, \ 756 const DeviceMemory<std::complex<float>> &a, int lda, \ 757 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 758 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 759 blas::Transpose transa, blas::Diagonal diag, uint64_t m, \ 760 uint64_t n, std::complex<double> alpha, \ 761 const DeviceMemory<std::complex<double>> &a, int lda, \ 762 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 763 bool DoBlasTrsmBatched( \ 764 Stream *stream, blas::Side side, blas::UpperLower uplo, \ 765 blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ 766 float alpha, const DeviceMemory<float *> &as, int lda, \ 767 DeviceMemory<float *> *bs, int ldb, int batch_count) override; \ 768 bool DoBlasTrsmBatched( \ 769 Stream *stream, blas::Side side, blas::UpperLower uplo, \ 770 blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ 771 double alpha, const DeviceMemory<double *> &as, int lda, \ 772 DeviceMemory<double *> *bs, int ldb, int batch_count) override; \ 773 bool DoBlasTrsmBatched(Stream *stream, blas::Side side, \ 774 blas::UpperLower uplo, blas::Transpose transa, \ 775 blas::Diagonal diag, uint64_t m, uint64 n, \ 776 std::complex<float> alpha, \ 777 const DeviceMemory<std::complex<float> *> &as, \ 778 int lda, DeviceMemory<std::complex<float> *> *bs, \ 779 int ldb, int batch_count) override; \ 780 bool DoBlasTrsmBatched(Stream *stream, blas::Side side, \ 781 blas::UpperLower uplo, blas::Transpose transa, \ 782 blas::Diagonal diag, uint64_t m, uint64 n, \ 783 std::complex<double> alpha, \ 784 const DeviceMemory<std::complex<double> *> &as, \ 785 int lda, DeviceMemory<std::complex<double> *> *bs, \ 786 int ldb, int batch_count) override; \ 787 port::Status 
GetVersion(std::string *version) override; 788 789 } // namespace blas 790 } // namespace stream_executor 791 792 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_ 793