• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce CUDA headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
22 
23 #include "absl/synchronization/mutex.h"
24 #include "third_party/gpus/cuda/include/cublasLt.h"
25 #include "third_party/gpus/cuda/include/cublas_v2.h"
26 #include "third_party/gpus/cuda/include/cuda.h"
27 #include "tensorflow/core/platform/thread_annotations.h"
28 #include "tensorflow/stream_executor/blas.h"
29 #include "tensorflow/stream_executor/host_or_device_scalar.h"
30 #include "tensorflow/stream_executor/platform/port.h"
31 #include "tensorflow/stream_executor/plugin_registry.h"
32 
33 typedef struct cublasContext *cublasHandle_t;
34 
35 namespace stream_executor {
36 
37 class Stream;
38 
39 namespace gpu {
40 
41 // Opaque and unique identifier for the cuBLAS plugin.
42 extern const PluginId kCuBlasPlugin;
43 
44 class GpuExecutor;
45 
46 // BLAS plugin for CUDA platform via cuBLAS library.
47 //
48 // This satisfies the platform-agnostic BlasSupport interface.
49 //
50 // Note that the cuBLAS handle that this encapsulates is implicitly tied to the
51 // context (and, as a result, the device) that the parent GpuExecutor is tied
52 // to. This simply happens as an artifact of creating the cuBLAS handle when a
53 // CUDA context is active.
54 //
55 // Thread-safe post-initialization.
56 class CUDABlas : public blas::BlasSupport {
57  public:
58   explicit CUDABlas(GpuExecutor *parent);
59 
60   // Allocates a cuBLAS handle.
61   bool Init();
62 
63   // Releases the cuBLAS handle, if present.
64   ~CUDABlas() override;
65 
66   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
67 
68  private:
69   // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
70   //
71   // cuBLAS is stateful, and only be associated with one stream (in order to
72   // enqueue dispatch) at a given time. As a result, this generally must be
73   // invoked before calling into cuBLAS.
74   bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
75 
76   // Returns the underlying CUDA stream.
77   cudaStream_t CUDAStream(Stream *stream);
78 
79   // A helper function that calls the real cuBLAS function together with error
80   // handling.
81   //
82   // cublas_func:        cuBLAS function pointer.
83   // cublas_name:        cuBLAS function name.
84   // stream:             Stream to enqueue the BLAS operation onto.
85   // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
86   //                     (true) or device (false).
87   // err_on_failure:     Whether to print an error if the cublas function fails.
88   // args:               Arguments of cuBLAS function.
89   template <typename FuncT, typename... Args>
90   bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
91                           bool pointer_mode_host, bool err_on_failure,
92                           cublasMath_t math_type, Args... args);
93 
94   // Convenience functions that call DoBlasInternalImpl with err_on_failure=true
95   // and math_type=CUBLAS_DEFAULT_MATH.
96   template <typename FuncT, typename... Args>
DoBlasInternal(FuncT cublas_func,Stream * stream,bool pointer_mode_host,Args...args)97   bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
98                       Args... args) {
99     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
100                               /*err_on_failure=*/true, CUBLAS_DEFAULT_MATH,
101                               args...);
102   }
103 
104   // A helper function to implement DoBlasGemmBatched interfaces for generic
105   // types.
106   template <typename T, typename Scalar, typename FuncT>
107   port::Status DoBlasGemmBatchedInternal(
108       FuncT cublas_func, Stream *stream, blas::Transpose transa,
109       blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
110       const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
111       const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
112       const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
113       int batch_count, ScratchAllocator *scratch_allocator);
114 
115   // Helper function for implementing DoBlasGemmWithAlgorithm.
116   template <typename InT, typename OutT, typename CompT>
117   bool DoBlasGemmWithAlgorithmImpl(
118       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
119       uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
120       const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
121       const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
122       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
123       blas::ProfileResult *output_profile_result);
124 
125   // Helper function for implementing DoBlasGemmWithProfiling.
126   template <typename T, typename ParamType>
127   bool DoBlasGemmWithProfilingImpl(
128       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
129       uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
130       int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
131       DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
132 
133   // Helper function for implementing DoBlasGemvWithProfiling.
134   template <typename T>
135   bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
136                                    uint64 m, uint64 n, const T &alpha,
137                                    const DeviceMemory<T> &a, int lda,
138                                    const DeviceMemory<T> &x, int incx,
139                                    const T &beta, DeviceMemory<T> *y, int incy,
140                                    blas::ProfileResult *output_profile_result);
141 
142   // Helper function for implementing DoBlasLtMatmul.
143   bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure,
144                               const blas::IBlasLtMatmulPlan *plan,
145                               const HostOrDeviceScalar<void> &alpha,
146                               DeviceMemoryBase a, DeviceMemoryBase b,
147                               const HostOrDeviceScalar<void> &beta,
148                               DeviceMemoryBase c, DeviceMemoryBase d,
149                               ScratchAllocator *scratch_allocator,
150                               const blas::IBlasLtMatmulAlgorithm *algorithm,
151                               DeviceMemoryBase bias);
152 
153   // Helper function for implementing GetBlasLtMatmulAlgorithms.
154   port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
155   GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
156                                     size_t max_workspace_size,
157                                     int max_algorithm_count,
158                                     bool for_remainder_batch = false);
159 
160   // Guards the cuBLAS handle for this device.
161   absl::Mutex mu_;
162 
163   // GpuExecutor which instantiated this CUDABlas.
164   // Immutable post-initialization.
165   GpuExecutor *parent_;
166 
167   // cuBLAS library handle on the device.
168   cublasHandle_t blas_ TF_GUARDED_BY(mu_);
169 
170 #if CUDA_VERSION >= 11000
171   // cuBLASLt library handle on the device.
172   cublasLtHandle_t blasLt_ GUARDED_BY(mu_);
173 #endif
174 
175   SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
176 };
177 
178 }  // namespace gpu
179 }  // namespace stream_executor
180 
181 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
182