• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce cuda headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
22 
#include <cstddef>
#include <memory>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
32 
33 typedef struct cublasContext *cublasHandle_t;
34 
35 namespace stream_executor {
36 
37 class Stream;
38 
39 namespace gpu {
40 
41 // Opaque and unique identifier for the cuBLAS plugin.
42 extern const PluginId kCuBlasPlugin;
43 
44 class GpuExecutor;
45 
46 // BLAS plugin for CUDA platform via cuBLAS library.
47 //
48 // This satisfies the platform-agnostic BlasSupport interface.
49 //
50 // Note that the cuBLAS handle that this encapsulates is implicitly tied to the
51 // context (and, as a result, the device) that the parent GpuExecutor is tied
52 // to. This simply happens as an artifact of creating the cuBLAS handle when a
53 // CUDA context is active.
54 //
55 // Thread-safe post-initialization.
56 class CUDABlas : public blas::BlasSupport {
57  public:
58   explicit CUDABlas(GpuExecutor *parent);
59 
60   // Allocates a cuBLAS handle.
61   bool Init();
62 
63   // Releases the cuBLAS handle, if present.
64   ~CUDABlas() override;
65 
66   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
67 
68  private:
69   // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
70   //
71   // cuBLAS is stateful, and only be associated with one stream (in order to
72   // enqueue dispatch) at a given time. As a result, this generally must be
73   // invoked before calling into cuBLAS.
74   bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
75 
76   // Returns the underlying CUDA stream.
77   cudaStream_t CUDAStream(Stream *stream);
78 
79   // A helper function that calls the real cuBLAS function together with error
80   // handling.
81   //
82   // cublas_func:        cuBLAS function pointer.
83   // cublas_name:        cuBLAS function name.
84   // stream:             Stream to enqueue the BLAS operation onto.
85   // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
86   //                     (true) or device (false).
87   // args:               Arguments of cuBLAS function.
88   template <typename FuncT, typename... Args>
89   port::Status DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
90                                   bool pointer_mode_host,
91                                   cublasMath_t math_type, Args... args);
92 
93   // Convenience functions that call DoBlasInternalImpl with err_on_failure=true
94   // and math_type=CUBLAS_DEFAULT_MATH.
95   template <typename FuncT, typename... Args>
DoBlasInternal(FuncT cublas_func,Stream * stream,bool pointer_mode_host,Args...args)96   bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
97                       Args... args) {
98     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
99                               CUBLAS_DEFAULT_MATH, args...)
100         .ok();
101   }
102 
103   // A helper function to implement DoBlasGemmBatched interfaces for generic
104   // types.
105   template <typename T, typename Scalar, typename FuncT>
106   port::Status DoBlasGemmBatchedInternal(
107       FuncT cublas_func, Stream *stream, blas::Transpose transa,
108       blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
109       const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
110       const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
111       const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
112       int batch_count, ScratchAllocator *scratch_allocator);
113 
114   // Helper function for implementing DoBlasGemmWithProfiling.
115   template <typename T, typename ParamType>
116   bool DoBlasGemmWithProfilingImpl(
117       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
118       uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
119       int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
120       DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
121 
122   // Helper function for implementing DoBlasGemvWithProfiling.
123   template <typename T>
124   bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
125                                    uint64 m, uint64 n, const T &alpha,
126                                    const DeviceMemory<T> &a, int lda,
127                                    const DeviceMemory<T> &x, int incx,
128                                    const T &beta, DeviceMemory<T> *y, int incy,
129                                    blas::ProfileResult *output_profile_result);
130 
131   // Helper function for implementing DoBlasLtMatmul.
132   bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure,
133                               const blas::IBlasLtMatmulPlan *plan,
134                               const HostOrDeviceScalar<void> &alpha,
135                               DeviceMemoryBase a, DeviceMemoryBase b,
136                               const HostOrDeviceScalar<void> &beta,
137                               DeviceMemoryBase c, DeviceMemoryBase d,
138                               ScratchAllocator *scratch_allocator,
139                               const blas::IBlasLtMatmulAlgorithm *algorithm,
140                               DeviceMemoryBase bias);
141 
142   // Helper function for implementing GetBlasLtMatmulAlgorithms.
143   port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
144   GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
145                                     size_t max_workspace_size,
146                                     int max_algorithm_count,
147                                     bool for_remainder_batch = false);
148 
149   // Guards the cuBLAS handle for this device.
150   absl::Mutex mu_;
151 
152   // GpuExecutor which instantiated this CUDABlas.
153   // Immutable post-initialization.
154   GpuExecutor *parent_;
155 
156   // cuBLAS library handle on the device.
157   cublasHandle_t blas_ TF_GUARDED_BY(mu_);
158 
159 #if CUDA_VERSION >= 11000
160   // cuBLASLt library handle on the device.
161   cublasLtHandle_t blasLt_ TF_GUARDED_BY(mu_);
162 #endif
163 
164   SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
165 };
166 
167 }  // namespace gpu
168 }  // namespace stream_executor
169 
170 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
171