/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce cuda headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"

// Forward-declare the cuBLAS handle type so this header does not pull in
// cublas.h (keeping CUDA headers out of non-CUDA translation units).
typedef struct cublasContext *cublasHandle_t;

namespace stream_executor {

class Stream;

namespace gpu {

// Opaque and unique identifier for the cuBLAS plugin.
extern const PluginId kCuBlasPlugin;

class GpuExecutor;

// BLAS plugin for CUDA platform via cuBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the cuBLAS handle when a
// CUDA context is active.
//
// Thread-safe post-initialization.
class CUDABlas : public blas::BlasSupport {
 public:
  explicit CUDABlas(GpuExecutor *parent);

  // Allocates a cuBLAS handle.
  bool Init();

  // Releases the cuBLAS handle, if present.
  ~CUDABlas() override;

  // Expands to the full set of BlasSupport virtual-method overrides
  // (DoBlasGemm, DoBlasAxpy, etc.); their definitions live in the .cc file.
  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // cuBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into cuBLAS.
  bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A helper function that calls the real cuBLAS function together with error
  // handling.
  //
  // cublas_func:        cuBLAS function pointer.
  // cublas_name:        cuBLAS function name.
  // stream:             Stream to enqueue the BLAS operation onto.
  // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
  //                     (true) or device (false).
  // err_on_failure:     Whether to print an error if the cublas function fails.
  // args:               Arguments of cuBLAS function.
  template <typename FuncT, typename... Args>
  bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                          bool pointer_mode_host, bool err_on_failure,
                          bool use_tensor_op_math, Args... args);

  // Convenience functions that call DoBlasInternalImpl with different values
  // for err_on_failure.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
                      Args... args) {
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/true, /*use_tensor_ops=*/false,
                              args...);
  }
  template <typename FuncT, typename... Args>
  bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
                               bool pointer_mode_host, Args... args) {
    // Tensor ops are hard-coded off in this path, but can still be enabled with
    // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/false,
                              /*use_tensor_ops=*/false, args...);
  }

  // A helper function to implement DoBlasGemmBatched interfaces for generic
  // types.
  template <typename T, typename Scalar, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT cublas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
      const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);

  // Helper function for implementing DoBlasGemmWithAlgorithm.
  template <typename InT, typename OutT, typename CompT>
  bool DoBlasGemmWithAlgorithmImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
      const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
      const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // mutex that guards the cuBLAS handle for this device.
  mutex mu_;

  // GpuExecutor which instantiated this CUDABlas.
  // Immutable post-initialization.
  GpuExecutor *parent_;

  // cuBLAS library handle on the device.
  cublasHandle_t blas_ GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_