• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Exposes the family of BLAS routines as pre-canned high performance calls for
17 // use in conjunction with the StreamExecutor abstraction.
18 //
19 // Note that this interface is optionally supported by platforms; see
20 // StreamExecutor::SupportsBlas() for details.
21 //
22 // This abstraction makes it simple to entrain BLAS operations on GPU data into
23 // a Stream -- users typically will not use this API directly, but will use the
24 // Stream builder methods to entrain these operations "under the hood". For
25 // example:
26 //
27 //  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
28 //  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
29 //  // ... populate x and y ...
30 //  Stream stream{stream_exec};
31 //  stream
32 //    .Init()
33 //    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
34 //  SE_CHECK_OK(stream.BlockHostUntilDone());
35 //
36 // By using stream operations in this manner the user can easily intermix custom
37 // kernel launches (via Stream::ThenLaunch()) with these pre-canned BLAS
38 // routines.
39 
40 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_
41 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_
42 
#include <complex>
#include <iosfwd>
#include <limits>
#include <string>
#include <vector>
45 
46 #include "tensorflow/compiler/xla/stream_executor/data_type.h"
47 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
48 #include "tensorflow/compiler/xla/stream_executor/dnn.pb.h"
49 #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
50 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
51 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
52 
// Forward declaration so that this header can name Eigen::half (for example as
// a DeviceMemory element type) without including any Eigen header.
namespace Eigen {
struct half;
}  // namespace Eigen
56 
57 namespace stream_executor {
58 
59 class Stream;
60 class ScratchAllocator;
61 
62 template <typename ElemT>
63 class DeviceMemory;
64 
65 template <typename ElemT>
66 class HostOrDeviceScalar;
67 
68 template <typename T>
69 using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>;  // non-absl ok
70 
71 namespace blas {
72 
// Specifies whether the input matrix will be transposed or
// transposed+conjugated before any BLAS operations.
enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };

// Returns a name for t.
std::string TransposeString(Transpose t);

// Specifies whether the upper or lower triangular part of a
// symmetric/Hermitian/triangular matrix is used.
enum class UpperLower { kUpper, kLower };

// Returns a name for ul.
std::string UpperLowerString(UpperLower ul);

// Specifies whether a matrix is unit triangular.
enum class Diagonal { kUnit, kNonUnit };

// Returns a name for d.
std::string DiagonalString(Diagonal d);

// Specifies whether the structured matrix operand (e.g. the triangular `a` of
// DoBlasTrsm) appears on the left or on the right of the other operand in the
// operation. Following the reference BLAS convention, for DoBlasTrsm kLeft
// means op(a) * x = alpha * b and kRight means x * op(a) = alpha * b.
enum class Side { kLeft, kRight };

// Returns a name for s.
std::string SideString(Side s);

// Type with which intermediate computations of a blas routine are performed.
//
// Some blas calls can perform computations with a type that's different than
// the type of their inputs/outputs.  This lets you e.g. multiply two matrices
// of int8s using float32s to store the matmul's intermediate values.
enum class ComputationType {
  kF16,  // 16-bit floating-point
  kF32,  // 32-bit floating-point
  kF64,  // 64-bit floating-point
  kI32,  // 32-bit integer
  // The below values use float32 for accumulation, but allow the inputs and
  // outputs to be downcast to a lower precision:
  kF16AsF32,   // Allow downcast to F16 precision.
  kBF16AsF32,  // Allow downcast to BF16 precision.
  kTF32AsF32,  // Allow downcast to TF32 precision.
};

// Converts a ComputationType to a string.
std::string ComputationTypeString(ComputationType ty);

std::ostream &operator<<(std::ostream &os, ComputationType ty);
121 
122 using dnn::DataType;
123 using dnn::ToDataType;
124 
125 // Converts a ComputationType to a string.
126 std::string DataTypeString(DataType ty);
127 
128 std::ostream &operator<<(std::ostream &os, DataType ty);
129 
// Opaque identifier for an "algorithm" used by a blas routine; it serves as a
// hint to the blas library. The named negative constants below are sentinel
// values rather than concrete algorithm ids.
using AlgorithmType = int64_t;

constexpr AlgorithmType kDefaultAlgorithm = -1;
constexpr AlgorithmType kDefaultBlasGemm = -2;
constexpr AlgorithmType kDefaultBlasGemv = -3;
constexpr AlgorithmType kNoAlgorithm = -4;

// The default GEMM algorithm is the same -1 as kDefaultAlgorithm. That value
// happens to match the CUBLAS_GEMM_DFALT constant, which is why cuda_blas.cc
// can static_cast an AlgorithmType to cublasGemmAlgo_t (guarded there by a
// static_assert). A blas implementation whose default algorithm has a
// different value must translate kDefaultGemmAlgo explicitly (e.g. through a
// ToWhateverGemmAlgo helper).
constexpr AlgorithmType kDefaultGemmAlgo = kDefaultAlgorithm;
146 
147 // Describes the result of a performance experiment, usually timing the speed of
148 // a particular AlgorithmType.
149 //
150 // If the call we were benchmarking failed (a common occurrence; not all
151 // algorithms are valid for all calls), is_valid() will be false.
152 class ProfileResult {
153  public:
is_valid()154   bool is_valid() const { return is_valid_; }
set_is_valid(bool val)155   void set_is_valid(bool val) { is_valid_ = val; }
algorithm()156   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)157   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
elapsed_time_in_ms()158   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
set_elapsed_time_in_ms(float val)159   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
160 
161  private:
162   bool is_valid_ = false;
163   AlgorithmType algorithm_ = kDefaultAlgorithm;
164   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
165 };
166 
167 class AlgorithmConfig {
168  public:
AlgorithmConfig()169   AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
AlgorithmConfig(AlgorithmType algorithm)170   explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
algorithm()171   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)172   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
173   bool operator==(const AlgorithmConfig &other) const {
174     return this->algorithm_ == other.algorithm_;
175   }
176   bool operator!=(const AlgorithmConfig &other) const {
177     return !(*this == other);
178   }
179   std::string ToString() const;
180 
181  private:
182   AlgorithmType algorithm_;
183 };
184 
// Opaque identifier specifying the precision to use in gemm calls;
// kDefaultComputePrecision (0) is the default value.
using ComputePrecision = int64_t;
constexpr ComputePrecision kDefaultComputePrecision = 0;
188 
189 // This struct contains the metadata of a matrix, e.g., its base address and
190 // dimensions.
191 struct MatrixDescriptor {
192   DeviceMemoryBase data;
193   int64_t leading_dim_stride;
194   int64_t batch_stride;
195   Transpose transpose;
196 
197   template <typename T>
castMatrixDescriptor198   DeviceMemory<T> cast() const {
199     return DeviceMemory<T>(data);
200   }
201 };
202 
203 // BLAS support interface -- this can be derived from a GPU executor when the
204 // underlying platform has an BLAS library implementation available. See
205 // StreamExecutor::AsBlas().
206 //
207 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in
208 // the system. Any operation that a user attempts to perform by enqueueing BLAS
209 // operations on a thread not-associated with the CUDA-context has unknown
210 // behavior at the current time; see b/13176597
211 class BlasSupport {
212  public:
~BlasSupport()213   virtual ~BlasSupport() {}
214 
215   // Performs a BLAS y <- ax+y operation.
216   virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
217                           const DeviceMemory<float> &x, int incx,
218                           DeviceMemory<float> *y, int incy) = 0;
219   virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha,
220                           const DeviceMemory<double> &x, int incx,
221                           DeviceMemory<double> *y, int incy) = 0;
222   virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count,
223                           std::complex<float> alpha,
224                           const DeviceMemory<std::complex<float>> &x, int incx,
225                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
226   virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count,
227                           std::complex<double> alpha,
228                           const DeviceMemory<std::complex<double>> &x, int incx,
229                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
230 
231   // Copies vector to another vector: y <- x.
232   virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count,
233                           const DeviceMemory<float> &x, int incx,
234                           DeviceMemory<float> *y, int incy) = 0;
235   virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count,
236                           const DeviceMemory<double> &x, int incx,
237                           DeviceMemory<double> *y, int incy) = 0;
238   virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count,
239                           const DeviceMemory<std::complex<float>> &x, int incx,
240                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
241   virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count,
242                           const DeviceMemory<std::complex<double>> &x, int incx,
243                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
244 
245   // Computes the product of a vector by a scalar: x <- a*x.
246   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
247                           DeviceMemory<float> *x, int incx) = 0;
248   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
249                           DeviceMemory<double> *x, int incx) = 0;
250   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
251                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
252   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
253                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
254   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count,
255                           std::complex<float> alpha,
256                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
257   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count,
258                           std::complex<double> alpha,
259                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
260 
261   // Computes a matrix-vector product using a general matrix.
262   //
263   //     y <- alpha * a * x + beta * y,
264   // or
265   //     y <- alpha * a' * x + beta * y,
266   // or
267   //     y <- alpha * conj(a') * x + beta * y,
268   //
269   // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
270   // with n(trans==kNoTranspose)/m(otherwise) elements;
271   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
272   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
273                           uint64_t n, float alpha, const DeviceMemory<float> &a,
274                           int lda, const DeviceMemory<float> &x, int incx,
275                           float beta, DeviceMemory<float> *y, int incy) = 0;
276   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
277                           uint64_t n, double alpha,
278                           const DeviceMemory<double> &a, int lda,
279                           const DeviceMemory<double> &x, int incx, double beta,
280                           DeviceMemory<double> *y, int incy) = 0;
281   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
282                           uint64_t n, std::complex<float> alpha,
283                           const DeviceMemory<std::complex<float>> &a, int lda,
284                           const DeviceMemory<std::complex<float>> &x, int incx,
285                           std::complex<float> beta,
286                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
287   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
288                           uint64_t n, std::complex<double> alpha,
289                           const DeviceMemory<std::complex<double>> &a, int lda,
290                           const DeviceMemory<std::complex<double>> &x, int incx,
291                           std::complex<double> beta,
292                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
293 
294   virtual bool DoBlasGemvWithProfiling(
295       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, float alpha,
296       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
297       int incx, float beta, DeviceMemory<float> *y, int incy,
298       ProfileResult *output_profile_result) = 0;
299   virtual bool DoBlasGemvWithProfiling(
300       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, double alpha,
301       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
302       int incx, double beta, DeviceMemory<double> *y, int incy,
303       ProfileResult *output_profile_result) = 0;
304   virtual bool DoBlasGemvWithProfiling(
305       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
306       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
307       int lda, const DeviceMemory<std::complex<float>> &x, int incx,
308       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
309       ProfileResult *output_profile_result) = 0;
310   virtual bool DoBlasGemvWithProfiling(
311       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
312       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
313       int lda, const DeviceMemory<std::complex<double>> &x, int incx,
314       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
315       int incy, ProfileResult *output_profile_result) = 0;
316 
317   // Computes a matrix-vector product using a symmetric band matrix.
318   //
319   //     y <- alpha * a * x + beta * y,
320   //
321   // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
322   // super-diagonals; x and y are n-element vectors.
323   virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
324                           uint64_t k, float alpha, const DeviceMemory<float> &a,
325                           int lda, const DeviceMemory<float> &x, int incx,
326                           float beta, DeviceMemory<float> *y, int incy) = 0;
327   virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
328                           uint64_t k, double alpha,
329                           const DeviceMemory<double> &a, int lda,
330                           const DeviceMemory<double> &x, int incx, double beta,
331                           DeviceMemory<double> *y, int incy) = 0;
332 
333   // Computes a matrix-matrix product with general matrices:
334   //
335   //     c <- alpha * op(a) * op(b) + beta * c,
336   //
337   // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
338   // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
339   // op(b) is a k-by-n matrix; c is an m-by-n matrix.
340   //
341   // Note: The half interface uses float precision internally; the version
342   // that uses half precision internally is not yet supported. There is no
343   // batched version of the half-precision interface.
344   //
345   // Alpha/beta type matches `dtype`, unless `dtype` is `Eigen::half`, in that
346   // case the expected alpha/beta type is `float`.
347   virtual port::Status DoBlasGemm(Stream *stream, blas::Transpose transa,
348                                   blas::Transpose transb, uint64_t m, uint64 n,
349                                   uint64_t k, DataType dtype, const void *alpha,
350                                   const DeviceMemoryBase &a, int lda,
351                                   const DeviceMemoryBase &b, int ldb,
352                                   const void *beta, DeviceMemoryBase *c,
353                                   int ldc, ComputePrecision precision) = 0;
354 
355   virtual bool DoBlasGemmWithProfiling(
356       Stream *stream, blas::Transpose transa, blas::Transpose transb,
357       uint64_t m, uint64_t n, uint64 k, float alpha,
358       const DeviceMemory<Eigen::half> &a, int lda,
359       const DeviceMemory<Eigen::half> &b, int ldb, float beta,
360       DeviceMemory<Eigen::half> *c, int ldc,
361       ProfileResult *output_profile_result) = 0;
362   virtual bool DoBlasGemmWithProfiling(
363       Stream *stream, blas::Transpose transa, blas::Transpose transb,
364       uint64_t m, uint64_t n, uint64 k, float alpha,
365       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
366       int ldb, float beta, DeviceMemory<float> *c, int ldc,
367       ProfileResult *output_profile_result) = 0;
368   virtual bool DoBlasGemmWithProfiling(
369       Stream *stream, blas::Transpose transa, blas::Transpose transb,
370       uint64_t m, uint64_t n, uint64 k, double alpha,
371       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
372       int ldb, double beta, DeviceMemory<double> *c, int ldc,
373       ProfileResult *output_profile_result) = 0;
374   virtual bool DoBlasGemmWithProfiling(
375       Stream *stream, blas::Transpose transa, blas::Transpose transb,
376       uint64_t m, uint64_t n, uint64 k, std::complex<float> alpha,
377       const DeviceMemory<std::complex<float>> &a, int lda,
378       const DeviceMemory<std::complex<float>> &b, int ldb,
379       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
380       ProfileResult *output_profile_result) = 0;
381   virtual bool DoBlasGemmWithProfiling(
382       Stream *stream, blas::Transpose transa, blas::Transpose transb,
383       uint64_t m, uint64_t n, uint64 k, std::complex<double> alpha,
384       const DeviceMemory<std::complex<double>> &a, int lda,
385       const DeviceMemory<std::complex<double>> &b, int ldb,
386       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
387       ProfileResult *output_profile_result) = 0;
388 
389   // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
390   virtual bool GetBlasGemmAlgorithms(
391       Stream *stream, std::vector<AlgorithmType> *out_algorithms) = 0;
392 
393   // Like DoBlasGemm, but accepts an algorithm and an compute type.
394   //
395   // The compute type lets you say (e.g.) that the inputs and outputs are
396   // Eigen::halfs, but you want the internal computations to be done with
397   // float32 precision.
398   //
399   // If output_profile_result is not null, a failure here does not put the
400   // stream in a failure state.  Instead, success/failure is indicated by
401   // output_profile_result->is_valid().  This lets you use this function for
402   // choosing the best algorithm among many (some of which may fail) without
403   // creating a new Stream for each attempt.
404   virtual port::Status DoBlasGemmWithAlgorithm(
405       Stream *stream, blas::Transpose transa, blas::Transpose transb,
406       uint64_t m, uint64_t n, uint64 k, const void *alpha,
407       const DeviceMemoryBase &a, DataType type_a, int lda,
408       const DeviceMemoryBase &b, DataType type_b, int ldb, const void *beta,
409       DeviceMemoryBase *c, DataType type_c, int ldc,
410       ComputationType computation_type, AlgorithmType algorithm,
411       ProfileResult *output_profile_result) = 0;
412 
413   virtual port::Status DoBlasGemmStridedBatchedWithAlgorithm(
414       Stream *stream, blas::Transpose transa, blas::Transpose transb,
415       uint64_t m, uint64_t n, uint64 k, const void *alpha,
416       const DeviceMemoryBase &a, DataType type_a, int lda, int64_t stride_a,
417       const DeviceMemoryBase &b, DataType type_b, int ldb, int64_t stride_b,
418       const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc,
419       int64_t stride_c, int batch_count, ComputationType computation_type,
420       AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
421 
422   // Computes a batch of matrix-matrix product with general matrices.
423   // This is a batched version of DoBlasGemm.
424   // The batched GEMM computes matrix product for each input/output in a, b,
425   // and c, which contain batch_count DeviceMemory objects.
426   virtual bool DoBlasGemmBatched(
427       Stream *stream, blas::Transpose transa, blas::Transpose transb,
428       uint64_t m, uint64_t n, uint64 k, float alpha,
429       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a,  // non-absl ok
430       int lda,
431       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b,  // non-absl ok
432       int ldb, float beta,
433       const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,  // non-absl ok
434       int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
435   virtual bool DoBlasGemmBatched(
436       Stream *stream, blas::Transpose transa, blas::Transpose transb,
437       uint64_t m, uint64_t n, uint64 k, float alpha,
438       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,  // non-absl ok
439       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,  // non-absl ok
440       float beta,
441       const port::ArraySlice<DeviceMemory<float> *> &c,  // non-absl ok
442       int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
443   virtual bool DoBlasGemmBatched(
444       Stream *stream, blas::Transpose transa, blas::Transpose transb,
445       uint64_t m, uint64_t n, uint64 k, double alpha,
446       const port::ArraySlice<DeviceMemory<double> *> &a,  // non-absl ok
447       int lda,
448       const port::ArraySlice<DeviceMemory<double> *> &b,  // non-absl ok
449       int ldb, double beta,
450       const port::ArraySlice<DeviceMemory<double> *> &c,  // non-absl ok
451       int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
452   virtual bool DoBlasGemmBatched(
453       Stream *stream, blas::Transpose transa, blas::Transpose transb,
454       uint64_t m, uint64_t n, uint64 k, std::complex<float> alpha,
455       const DeviceMemorySlice<std::complex<float>> &a, int lda,
456       const DeviceMemorySlice<std::complex<float>> &b, int ldb,
457       std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c,
458       int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
459   virtual bool DoBlasGemmBatched(
460       Stream *stream, blas::Transpose transa, blas::Transpose transb,
461       uint64_t m, uint64_t n, uint64 k, std::complex<double> alpha,
462       const DeviceMemorySlice<std::complex<double>> &a, int lda,
463       const DeviceMemorySlice<std::complex<double>> &b, int ldb,
464       std::complex<double> beta,
465       const DeviceMemorySlice<std::complex<double>> &c, int ldc,
466       int batch_count, ScratchAllocator *scratch_allocator) = 0;
467 
468   // Batched gemm with strides instead of pointer arrays.
469   virtual port::Status DoBlasGemmStridedBatched(
470       Stream *stream, blas::Transpose transa, blas::Transpose transb,
471       uint64_t m, uint64_t n, uint64 k, DataType dtype, const void *alpha,
472       const DeviceMemoryBase &a, int lda, int64_t stride_a,
473       const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
474       DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) = 0;
475 
476   // Solves a triangular matrix equation.
477   //
478   //     op(a) * x = alpha * b,
479   // or
480   //     x * op(a) = alpha * b
481   //
482   // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
483   // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
484   // or op(a) = conj(a').
485   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
486                           blas::UpperLower uplo, blas::Transpose transa,
487                           blas::Diagonal diag, uint64_t m, uint64 n,
488                           float alpha, const DeviceMemory<float> &a, int lda,
489                           DeviceMemory<float> *b, int ldb) = 0;
490   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
491                           blas::UpperLower uplo, blas::Transpose transa,
492                           blas::Diagonal diag, uint64_t m, uint64 n,
493                           double alpha, const DeviceMemory<double> &a, int lda,
494                           DeviceMemory<double> *b, int ldb) = 0;
495   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
496                           blas::UpperLower uplo, blas::Transpose transa,
497                           blas::Diagonal diag, uint64_t m, uint64 n,
498                           std::complex<float> alpha,
499                           const DeviceMemory<std::complex<float>> &a, int lda,
500                           DeviceMemory<std::complex<float>> *b, int ldb) = 0;
501   virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
502                           blas::UpperLower uplo, blas::Transpose transa,
503                           blas::Diagonal diag, uint64_t m, uint64 n,
504                           std::complex<double> alpha,
505                           const DeviceMemory<std::complex<double>> &a, int lda,
506                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
507 
508   // Same as DoBlasTrsm, but operates over a list of a's and b's.  The lists
509   // `as` and `bs` must have the same length.
510   virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side,
511                                  blas::UpperLower uplo, blas::Transpose transa,
512                                  blas::Diagonal diag, uint64_t m, uint64 n,
513                                  float alpha, const DeviceMemory<float *> &as,
514                                  int lda, DeviceMemory<float *> *bs, int ldb,
515                                  int batch_count) = 0;
516   virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side,
517                                  blas::UpperLower uplo, blas::Transpose transa,
518                                  blas::Diagonal diag, uint64_t m, uint64 n,
519                                  double alpha, const DeviceMemory<double *> &as,
520                                  int lda, DeviceMemory<double *> *bs, int ldb,
521                                  int batch_count) = 0;
522   virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side,
523                                  blas::UpperLower uplo, blas::Transpose transa,
524                                  blas::Diagonal diag, uint64_t m, uint64 n,
525                                  std::complex<float> alpha,
526                                  const DeviceMemory<std::complex<float> *> &as,
527                                  int lda,
528                                  DeviceMemory<std::complex<float> *> *bs,
529                                  int ldb, int batch_count) = 0;
530   virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side,
531                                  blas::UpperLower uplo, blas::Transpose transa,
532                                  blas::Diagonal diag, uint64_t m, uint64 n,
533                                  std::complex<double> alpha,
534                                  const DeviceMemory<std::complex<double> *> &as,
535                                  int lda,
536                                  DeviceMemory<std::complex<double> *> *bs,
537                                  int ldb, int batch_count) = 0;
538 
539   virtual port::Status GetVersion(std::string *version) = 0;
540 
541  protected:
BlasSupport()542   BlasSupport() {}
543 
544  private:
545   SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
546 };
547 
548 // Macro used to quickly declare overrides for abstract virtuals in the
549 // BlasSupport base class.
550 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES                  \
551   bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,            \
552                   const DeviceMemory<float> &x, int incx,                      \
553                   DeviceMemory<float> *y, int incy) override;                  \
554   bool DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha,           \
555                   const DeviceMemory<double> &x, int incx,                     \
556                   DeviceMemory<double> *y, int incy) override;                 \
557   bool DoBlasAxpy(Stream *stream, uint64_t elem_count,                         \
558                   std::complex<float> alpha,                                   \
559                   const DeviceMemory<std::complex<float>> &x, int incx,        \
560                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
561   bool DoBlasAxpy(Stream *stream, uint64_t elem_count,                         \
562                   std::complex<double> alpha,                                  \
563                   const DeviceMemory<std::complex<double>> &x, int incx,       \
564                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
565   bool DoBlasCopy(Stream *stream, uint64_t elem_count,                         \
566                   const DeviceMemory<float> &x, int incx,                      \
567                   DeviceMemory<float> *y, int incy) override;                  \
568   bool DoBlasCopy(Stream *stream, uint64_t elem_count,                         \
569                   const DeviceMemory<double> &x, int incx,                     \
570                   DeviceMemory<double> *y, int incy) override;                 \
571   bool DoBlasCopy(Stream *stream, uint64_t elem_count,                         \
572                   const DeviceMemory<std::complex<float>> &x, int incx,        \
573                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
574   bool DoBlasCopy(Stream *stream, uint64_t elem_count,                         \
575                   const DeviceMemory<std::complex<double>> &x, int incx,       \
576                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
577   bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,            \
578                   DeviceMemory<float> *x, int incx) override;                  \
579   bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,           \
580                   DeviceMemory<double> *x, int incx) override;                 \
581   bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,            \
582                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
583   bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,           \
584                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
585   bool DoBlasScal(Stream *stream, uint64_t elem_count,                         \
586                   std::complex<float> alpha,                                   \
587                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
588   bool DoBlasScal(Stream *stream, uint64_t elem_count,                         \
589                   std::complex<double> alpha,                                  \
590                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
591   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \
592                   float alpha, const DeviceMemory<float> &a, int lda,          \
593                   const DeviceMemory<float> &x, int incx, float beta,          \
594                   DeviceMemory<float> *y, int incy) override;                  \
595   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \
596                   double alpha, const DeviceMemory<double> &a, int lda,        \
597                   const DeviceMemory<double> &x, int incx, double beta,        \
598                   DeviceMemory<double> *y, int incy) override;                 \
599   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \
600                   std::complex<float> alpha,                                   \
601                   const DeviceMemory<std::complex<float>> &a, int lda,         \
602                   const DeviceMemory<std::complex<float>> &x, int incx,        \
603                   std::complex<float> beta,                                    \
604                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
605   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \
606                   std::complex<double> alpha,                                  \
607                   const DeviceMemory<std::complex<double>> &a, int lda,        \
608                   const DeviceMemory<std::complex<double>> &x, int incx,       \
609                   std::complex<double> beta,                                   \
610                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
611   bool DoBlasGemvWithProfiling(                                                \
612       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,             \
613       float alpha, const DeviceMemory<float> &a, int lda,                      \
614       const DeviceMemory<float> &x, int incx, float beta,                      \
615       DeviceMemory<float> *y, int incy,                                        \
616       blas::ProfileResult *output_profile_result) override;                    \
617   bool DoBlasGemvWithProfiling(                                                \
618       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,             \
619       double alpha, const DeviceMemory<double> &a, int lda,                    \
620       const DeviceMemory<double> &x, int incx, double beta,                    \
621       DeviceMemory<double> *y, int incy,                                       \
622       blas::ProfileResult *output_profile_result) override;                    \
623   bool DoBlasGemvWithProfiling(                                                \
624       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,             \
625       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,   \
626       int lda, const DeviceMemory<std::complex<float>> &x, int incx,           \
627       std::complex<float> beta, DeviceMemory<std::complex<float>> *y,          \
628       int incy, blas::ProfileResult *output_profile_result) override;          \
629   bool DoBlasGemvWithProfiling(                                                \
630       Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,             \
631       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
632       int lda, const DeviceMemory<std::complex<double>> &x, int incx,          \
633       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,        \
634       int incy, blas::ProfileResult *output_profile_result) override;          \
635   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, uint64 k, \
636                   float alpha, const DeviceMemory<float> &a, int lda,          \
637                   const DeviceMemory<float> &x, int incx, float beta,          \
638                   DeviceMemory<float> *y, int incy) override;                  \
639   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n, uint64 k, \
640                   double alpha, const DeviceMemory<double> &a, int lda,        \
641                   const DeviceMemory<double> &x, int incx, double beta,        \
642                   DeviceMemory<double> *y, int incy) override;                 \
643   port::Status DoBlasGemm(                                                     \
644       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
645       uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \
646       const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb,  \
647       const void *beta, DeviceMemoryBase *c, int ldc,                          \
648       blas::ComputePrecision precision) override;                              \
649   bool DoBlasGemmWithProfiling(                                                \
650       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
651       uint64_t m, uint64 n, uint64 k, float alpha,                             \
652       const DeviceMemory<Eigen::half> &a, int lda,                             \
653       const DeviceMemory<Eigen::half> &b, int ldb, float beta,                 \
654       DeviceMemory<Eigen::half> *c, int ldc,                                   \
655       blas::ProfileResult *output_profile_result) override;                    \
656   bool DoBlasGemmWithProfiling(                                                \
657       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
658       uint64_t m, uint64 n, uint64 k, float alpha,                             \
659       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,     \
660       int ldb, float beta, DeviceMemory<float> *c, int ldc,                    \
661       blas::ProfileResult *output_profile_result) override;                    \
662   bool DoBlasGemmWithProfiling(                                                \
663       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
664       uint64_t m, uint64 n, uint64 k, double alpha,                            \
665       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
666       int ldb, double beta, DeviceMemory<double> *c, int ldc,                  \
667       blas::ProfileResult *output_profile_result) override;                    \
668   bool DoBlasGemmWithProfiling(                                                \
669       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
670       uint64_t m, uint64 n, uint64 k, std::complex<float> alpha,               \
671       const DeviceMemory<std::complex<float>> &a, int lda,                     \
672       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
673       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
674       blas::ProfileResult *output_profile_result) override;                    \
675   bool DoBlasGemmWithProfiling(                                                \
676       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
677       uint64_t m, uint64 n, uint64 k, std::complex<double> alpha,              \
678       const DeviceMemory<std::complex<double>> &a, int lda,                    \
679       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
680       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
681       int ldc, blas::ProfileResult *output_profile_result) override;           \
682   bool GetBlasGemmAlgorithms(Stream *stream,                                   \
683                              std::vector<blas::AlgorithmType> *out_algorithms) \
684       override;                                                                \
685   port::Status DoBlasGemmWithAlgorithm(                                        \
686       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
687       uint64_t m, uint64 n, uint64 k, const void *alpha,                       \
688       const DeviceMemoryBase &a, blas::DataType type_a, int lda,               \
689       const DeviceMemoryBase &b, blas::DataType type_b, int ldb,               \
690       const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc,   \
691       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
692       blas::ProfileResult *output_profile_result) override;                    \
693   bool DoBlasGemmBatched(                                                      \
694       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
695       uint64_t m, uint64 n, uint64 k, float alpha,                             \
696       const DeviceMemorySlice<Eigen::half> &a, int lda,                        \
697       const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta,            \
698       const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count,       \
699       ScratchAllocator *scratch_allocator) override;                           \
700   bool DoBlasGemmBatched(                                                      \
701       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
702       uint64_t m, uint64 n, uint64 k, float alpha,                             \
703       const DeviceMemorySlice<float> &a, int lda,                              \
704       const DeviceMemorySlice<float> &b, int ldb, float beta,                  \
705       const DeviceMemorySlice<float> &c, int ldc, int batch_count,             \
706       ScratchAllocator *scratch_allocator) override;                           \
707   bool DoBlasGemmBatched(                                                      \
708       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
709       uint64_t m, uint64 n, uint64 k, double alpha,                            \
710       const DeviceMemorySlice<double> &a, int lda,                             \
711       const DeviceMemorySlice<double> &b, int ldb, double beta,                \
712       const DeviceMemorySlice<double> &c, int ldc, int batch_count,            \
713       ScratchAllocator *scratch_allocator) override;                           \
714   bool DoBlasGemmBatched(                                                      \
715       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
716       uint64_t m, uint64 n, uint64 k, std::complex<float> alpha,               \
717       const DeviceMemorySlice<std::complex<float>> &a, int lda,                \
718       const DeviceMemorySlice<std::complex<float>> &b, int ldb,                \
719       std::complex<float> beta,                                                \
720       const DeviceMemorySlice<std::complex<float>> &c, int ldc,                \
721       int batch_count, ScratchAllocator *scratch_allocator) override;          \
722   bool DoBlasGemmBatched(                                                      \
723       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
724       uint64_t m, uint64 n, uint64 k, std::complex<double> alpha,              \
725       const DeviceMemorySlice<std::complex<double>> &a, int lda,               \
726       const DeviceMemorySlice<std::complex<double>> &b, int ldb,               \
727       std::complex<double> beta,                                               \
728       const DeviceMemorySlice<std::complex<double>> &c, int ldc,               \
729       int batch_count, ScratchAllocator *scratch_allocator) override;          \
730   port::Status DoBlasGemmStridedBatched(                                       \
731       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
732       uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \
733       const DeviceMemoryBase &a, int lda, int64_t stride_a,                    \
734       const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,  \
735       DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count);        \
736   port::Status DoBlasGemmStridedBatchedWithAlgorithm(                          \
737       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
738       uint64_t m, uint64 n, uint64 k, const void *alpha,                       \
739       const DeviceMemoryBase &a, blas::DataType type_a, int lda,               \
740       int64_t stride_a, const DeviceMemoryBase &b, blas::DataType type_b,      \
741       int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c,        \
742       blas::DataType type_c, int ldc, int64_t stride_c, int batch_count,       \
743       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
744       blas::ProfileResult *output_profile_result) override;                    \
745   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
746                   blas::Transpose transa, blas::Diagonal diag, uint64_t m,     \
747                   uint64_t n, float alpha, const DeviceMemory<float> &a,       \
748                   int lda, DeviceMemory<float> *b, int ldb) override;          \
749   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
750                   blas::Transpose transa, blas::Diagonal diag, uint64_t m,     \
751                   uint64_t n, double alpha, const DeviceMemory<double> &a,     \
752                   int lda, DeviceMemory<double> *b, int ldb) override;         \
753   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
754                   blas::Transpose transa, blas::Diagonal diag, uint64_t m,     \
755                   uint64_t n, std::complex<float> alpha,                       \
756                   const DeviceMemory<std::complex<float>> &a, int lda,         \
757                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
758   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
759                   blas::Transpose transa, blas::Diagonal diag, uint64_t m,     \
760                   uint64_t n, std::complex<double> alpha,                      \
761                   const DeviceMemory<std::complex<double>> &a, int lda,        \
762                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
763   bool DoBlasTrsmBatched(                                                      \
764       Stream *stream, blas::Side side, blas::UpperLower uplo,                  \
765       blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n,       \
766       float alpha, const DeviceMemory<float *> &as, int lda,                   \
767       DeviceMemory<float *> *bs, int ldb, int batch_count) override;           \
768   bool DoBlasTrsmBatched(                                                      \
769       Stream *stream, blas::Side side, blas::UpperLower uplo,                  \
770       blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n,       \
771       double alpha, const DeviceMemory<double *> &as, int lda,                 \
772       DeviceMemory<double *> *bs, int ldb, int batch_count) override;          \
773   bool DoBlasTrsmBatched(Stream *stream, blas::Side side,                      \
774                          blas::UpperLower uplo, blas::Transpose transa,        \
775                          blas::Diagonal diag, uint64_t m, uint64 n,            \
776                          std::complex<float> alpha,                            \
777                          const DeviceMemory<std::complex<float> *> &as,        \
778                          int lda, DeviceMemory<std::complex<float> *> *bs,     \
779                          int ldb, int batch_count) override;                   \
780   bool DoBlasTrsmBatched(Stream *stream, blas::Side side,                      \
781                          blas::UpperLower uplo, blas::Transpose transa,        \
782                          blas::Diagonal diag, uint64_t m, uint64 n,            \
783                          std::complex<double> alpha,                           \
784                          const DeviceMemory<std::complex<double> *> &as,       \
785                          int lda, DeviceMemory<std::complex<double> *> *bs,    \
786                          int ldb, int batch_count) override;                   \
787   port::Status GetVersion(std::string *version) override;
788 
789 }  // namespace blas
790 }  // namespace stream_executor
791 
792 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_
793