/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS
// library capabilities, and is only included into CUDA implementation code --
// it will not introduce CUDA headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_

#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"

typedef struct cublasContext *cublasHandle_t;

namespace stream_executor {

class Stream;

namespace gpu {

// Opaque and unique identifier for the cuBLAS plugin.
extern const PluginId kCuBlasPlugin;

class GpuExecutor;

// BLAS plugin for the CUDA platform via the cuBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the cuBLAS handle when a
// CUDA context is active.
//
// Thread-safe post-initialization.
class CUDABlas : public blas::BlasSupport {
 public:
  explicit CUDABlas(GpuExecutor *parent);

  // Allocates a cuBLAS handle.
  bool Init();

  // Releases the cuBLAS handle, if present.
  ~CUDABlas() override;

  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // cuBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into cuBLAS.
  bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the underlying CUDA stream.
  cudaStream_t CUDAStream(Stream *stream);

  // A helper function that calls the real cuBLAS function together with error
  // handling.
  //
  // cublas_func:       cuBLAS function pointer.
  // stream:            Stream to enqueue the BLAS operation onto.
  // pointer_mode_host: Indicates whether the pointer to a scalar value lives
  //                    on the host (true) or on the device (false).
  // math_type:         cuBLAS math mode to apply for this call.
  // args:              Arguments of the cuBLAS function.
  template <typename FuncT, typename... Args>
  port::Status DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                  bool pointer_mode_host,
                                  cublasMath_t math_type, Args... args);
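  // A minimal sketch of the call sequence DoBlasInternalImpl performs, to
  // make the contract above concrete. This is illustrative only -- the real
  // implementation lives in cuda_blas.cc, and its locking, logging, and
  // error handling may differ:
  //
  //   absl::MutexLock lock(&mu_);
  //   SetStream(stream);  // i.e. cublasSetStream(blas_, CUDAStream(stream))
  //   cublasSetPointerMode(blas_, pointer_mode_host
  //                                   ? CUBLAS_POINTER_MODE_HOST
  //                                   : CUBLAS_POINTER_MODE_DEVICE);
  //   cublasSetMathMode(blas_, math_type);  // e.g. CUBLAS_TENSOR_OP_MATH
  //   cublasStatus_t ret = cublas_func(blas_, args...);
  //   return ret == CUBLAS_STATUS_SUCCESS
  //              ? port::Status::OK()
  //              : port::Status(port::error::INTERNAL, /*error message*/);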
  // Convenience function that calls DoBlasInternalImpl with
  // math_type=CUBLAS_DEFAULT_MATH and reduces the resulting port::Status to
  // a bool.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT cublas_func, Stream *stream,
                      bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              CUBLAS_DEFAULT_MATH, args...)
        .ok();
  }

  // A helper function to implement the DoBlasGemmBatched interfaces for
  // generic types.
  template <typename T, typename Scalar, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT cublas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
      const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasLtMatmul.
  bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure,
                              const blas::IBlasLtMatmulPlan *plan,
                              const HostOrDeviceScalar<void> &alpha,
                              DeviceMemoryBase a, DeviceMemoryBase b,
                              const HostOrDeviceScalar<void> &beta,
                              DeviceMemoryBase c, DeviceMemoryBase d,
                              ScratchAllocator *scratch_allocator,
                              const blas::IBlasLtMatmulAlgorithm *algorithm,
                              DeviceMemoryBase bias);

  // Helper function for implementing GetBlasLtMatmulAlgorithms.
  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
                                    size_t max_workspace_size,
                                    int max_algorithm_count,
                                    bool for_remainder_batch = false);

  // Guards the cuBLAS handle for this device.
  absl::Mutex mu_;

  // GpuExecutor which instantiated this CUDABlas.
  // Immutable post-initialization.
  GpuExecutor *parent_;

  // cuBLAS library handle on the device.
  cublasHandle_t blas_ TF_GUARDED_BY(mu_);

#if CUDA_VERSION >= 11000
  // cuBLASLt library handle on the device.
  cublasLtHandle_t blasLt_ TF_GUARDED_BY(mu_);
#endif

  SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
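// Example usage (illustrative, not part of this header): client code does not
// instantiate CUDABlas directly. The plugin is registered with the plugin
// registry under kCuBlasPlugin and is reached through the Stream API. A
// hedged sketch, where `executor`, the matrix buffers, and the dimensions are
// assumed to be set up elsewhere:
//
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenBlasGemm(blas::Transpose::kNoTranspose,
//                       blas::Transpose::kNoTranspose, m, n, k, alpha, a, lda,
//                       b, ldb, beta, &c, ldc);
//   port::Status status = stream.BlockHostUntilDone();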