1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library 17 // capabilities, and is only included into CUDA implementation code -- it will 18 // not introduce cuda headers into other code. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ 22 23 #include "absl/synchronization/mutex.h" 24 #include "third_party/gpus/cuda/include/cublasLt.h" 25 #include "third_party/gpus/cuda/include/cublas_v2.h" 26 #include "third_party/gpus/cuda/include/cuda.h" 27 #include "tensorflow/core/platform/thread_annotations.h" 28 #include "tensorflow/stream_executor/blas.h" 29 #include "tensorflow/stream_executor/host_or_device_scalar.h" 30 #include "tensorflow/stream_executor/platform/port.h" 31 #include "tensorflow/stream_executor/plugin_registry.h" 32 33 typedef struct cublasContext *cublasHandle_t; 34 35 namespace stream_executor { 36 37 class Stream; 38 39 namespace gpu { 40 41 // Opaque and unique identifier for the cuBLAS plugin. 42 extern const PluginId kCuBlasPlugin; 43 44 class GpuExecutor; 45 46 // BLAS plugin for CUDA platform via cuBLAS library. 47 // 48 // This satisfies the platform-agnostic BlasSupport interface. 49 // 50 // Note that the cuBLAS handle that this encapsulates is implicitly tied to the 51 // context (and, as a result, the device) that the parent GpuExecutor is tied 52 // to. This simply happens as an artifact of creating the cuBLAS handle when a 53 // CUDA context is active. 54 // 55 // Thread-safe post-initialization. 56 class CUDABlas : public blas::BlasSupport { 57 public: 58 explicit CUDABlas(GpuExecutor *parent); 59 60 // Allocates a cuBLAS handle. 61 bool Init(); 62 63 // Releases the cuBLAS handle, if present. 64 ~CUDABlas() override; 65 66 TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES 67 68 private: 69 // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream. 70 // 71 // cuBLAS is stateful, and only be associated with one stream (in order to 72 // enqueue dispatch) at a given time. As a result, this generally must be 73 // invoked before calling into cuBLAS. 74 bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 75 76 // Returns the underlying CUDA stream. 77 cudaStream_t CUDAStream(Stream *stream); 78 79 // A helper function that calls the real cuBLAS function together with error 80 // handling. 81 // 82 // cublas_func: cuBLAS function pointer. 83 // cublas_name: cuBLAS function name. 84 // stream: Stream to enqueue the BLAS operation onto. 85 // pointer_mode_host: Indicate if the pointer to a scalar value is from host 86 // (true) or device (false). 87 // err_on_failure: Whether to print an error if the cublas function fails. 88 // args: Arguments of cuBLAS function. 89 template <typename FuncT, typename... Args> 90 bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream, 91 bool pointer_mode_host, bool err_on_failure, 92 cublasMath_t math_type, Args... args); 93 94 // Convenience functions that call DoBlasInternalImpl with err_on_failure=true 95 // and math_type=CUBLAS_DEFAULT_MATH. 96 template <typename FuncT, typename... Args> DoBlasInternal(FuncT cublas_func,Stream * stream,bool pointer_mode_host,Args...args)97 bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, 98 Args... args) { 99 return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, 100 /*err_on_failure=*/true, CUBLAS_DEFAULT_MATH, 101 args...); 102 } 103 104 // A helper function to implement DoBlasGemmBatched interfaces for generic 105 // types. 106 template <typename T, typename Scalar, typename FuncT> 107 port::Status DoBlasGemmBatchedInternal( 108 FuncT cublas_func, Stream *stream, blas::Transpose transa, 109 blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha, 110 const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda, 111 const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta, 112 const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc, 113 int batch_count, ScratchAllocator *scratch_allocator); 114 115 // Helper function for implementing DoBlasGemmWithAlgorithm. 116 template <typename InT, typename OutT, typename CompT> 117 bool DoBlasGemmWithAlgorithmImpl( 118 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 119 uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha, 120 const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb, 121 const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc, 122 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 123 blas::ProfileResult *output_profile_result); 124 125 // Helper function for implementing DoBlasGemmWithProfiling. 126 template <typename T, typename ParamType> 127 bool DoBlasGemmWithProfilingImpl( 128 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 129 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a, 130 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta, 131 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result); 132 133 // Helper function for implementing DoBlasGemvWithProfiling. 134 template <typename T> 135 bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans, 136 uint64 m, uint64 n, const T &alpha, 137 const DeviceMemory<T> &a, int lda, 138 const DeviceMemory<T> &x, int incx, 139 const T &beta, DeviceMemory<T> *y, int incy, 140 blas::ProfileResult *output_profile_result); 141 142 // Helper function for implementing DoBlasLtMatmul. 143 bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure, 144 const blas::IBlasLtMatmulPlan *plan, 145 const HostOrDeviceScalar<void> &alpha, 146 DeviceMemoryBase a, DeviceMemoryBase b, 147 const HostOrDeviceScalar<void> &beta, 148 DeviceMemoryBase c, DeviceMemoryBase d, 149 ScratchAllocator *scratch_allocator, 150 const blas::IBlasLtMatmulAlgorithm *algorithm, 151 DeviceMemoryBase bias); 152 153 // Helper function for implementing GetBlasLtMatmulAlgorithms. 154 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> 155 GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan, 156 size_t max_workspace_size, 157 int max_algorithm_count, 158 bool for_remainder_batch = false); 159 160 // Guards the cuBLAS handle for this device. 161 absl::Mutex mu_; 162 163 // GpuExecutor which instantiated this CUDABlas. 164 // Immutable post-initialization. 165 GpuExecutor *parent_; 166 167 // cuBLAS library handle on the device. 168 cublasHandle_t blas_ TF_GUARDED_BY(mu_); 169 170 #if CUDA_VERSION >= 11000 171 // cuBLASLt library handle on the device. 172 cublasLtHandle_t blasLt_ GUARDED_BY(mu_); 173 #endif 174 175 SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas); 176 }; 177 178 } // namespace gpu 179 } // namespace stream_executor 180 181 #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ 182