• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
17 // capabilities, and is only included into CUDA implementation code -- it will
18 // not introduce cuda headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
22 
23 #include "tensorflow/stream_executor/blas.h"
24 #include "tensorflow/stream_executor/host_or_device_scalar.h"
25 #include "tensorflow/stream_executor/platform/mutex.h"
26 #include "tensorflow/stream_executor/platform/port.h"
27 #include "tensorflow/stream_executor/platform/thread_annotations.h"
28 #include "tensorflow/stream_executor/plugin_registry.h"
29 
30 typedef struct cublasContext *cublasHandle_t;
31 
32 namespace stream_executor {
33 
34 class Stream;
35 
36 namespace gpu {
37 
38 // Opaque and unique identifier for the cuBLAS plugin.
39 extern const PluginId kCuBlasPlugin;
40 
41 class GpuExecutor;
42 
43 // BLAS plugin for CUDA platform via cuBLAS library.
44 //
45 // This satisfies the platform-agnostic BlasSupport interface.
46 //
47 // Note that the cuBLAS handle that this encapsulates is implicitly tied to the
48 // context (and, as a result, the device) that the parent GpuExecutor is tied
49 // to. This simply happens as an artifact of creating the cuBLAS handle when a
50 // CUDA context is active.
51 //
52 // Thread-safe post-initialization.
53 class CUDABlas : public blas::BlasSupport {
54  public:
55   explicit CUDABlas(GpuExecutor *parent);
56 
57   // Allocates a cuBLAS handle.
58   bool Init();
59 
60   // Releases the cuBLAS handle, if present.
61   ~CUDABlas() override;
62 
63   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
64 
65  private:
66   // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
67   //
68   // cuBLAS is stateful, and only be associated with one stream (in order to
69   // enqueue dispatch) at a given time. As a result, this generally must be
70   // invoked before calling into cuBLAS.
71   bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
72 
73   // A helper function that calls the real cuBLAS function together with error
74   // handling.
75   //
76   // cublas_func:        cuBLAS function pointer.
77   // cublas_name:        cuBLAS function name.
78   // stream:             Stream to enqueue the BLAS operation onto.
79   // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
80   //                     (true) or device (false).
81   // err_on_failure:     Whether to print an error if the cublas function fails.
82   // args:               Arguments of cuBLAS function.
83   template <typename FuncT, typename... Args>
84   bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
85                           bool pointer_mode_host, bool err_on_failure,
86                           bool use_tensor_op_math, Args... args);
87 
88   // Convenience functions that call DoBlasInternalImpl with different values
89   // for err_on_failure.
90   template <typename FuncT, typename... Args>
DoBlasInternal(FuncT cublas_func,Stream * stream,bool pointer_mode_host,Args...args)91   bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
92                       Args... args) {
93     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
94                               /*err_on_failure=*/true, /*use_tensor_ops=*/false,
95                               args...);
96   }
97   template <typename FuncT, typename... Args>
DoBlasInternalFailureOK(FuncT cublas_func,Stream * stream,bool pointer_mode_host,Args...args)98   bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
99                                bool pointer_mode_host, Args... args) {
100     // Tensor ops are hard-coded off in this path, but can still be enabled with
101     // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
102     return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
103                               /*err_on_failure=*/false,
104                               /*use_tensor_ops=*/false, args...);
105   }
106 
107   // A helper function to implement DoBlasGemmBatched interfaces for generic
108   // types.
109   template <typename T, typename Scalar, typename FuncT>
110   port::Status DoBlasGemmBatchedInternal(
111       FuncT cublas_func, Stream *stream, blas::Transpose transa,
112       blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
113       const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
114       const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
115       const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
116       int batch_count, ScratchAllocator *scratch_allocator);
117 
118   // Helper function for implementing DoBlasGemmWithAlgorithm.
119   template <typename InT, typename OutT, typename CompT>
120   bool DoBlasGemmWithAlgorithmImpl(
121       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
122       uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
123       const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
124       const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
125       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
126       blas::ProfileResult *output_profile_result);
127 
128   // Helper function for implementing DoBlasGemmWithProfiling.
129   template <typename T, typename ParamType>
130   bool DoBlasGemmWithProfilingImpl(
131       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
132       uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
133       int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
134       DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
135 
136   // Helper function for implementing DoBlasGemvWithProfiling.
137   template <typename T>
138   bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
139                                    uint64 m, uint64 n, const T &alpha,
140                                    const DeviceMemory<T> &a, int lda,
141                                    const DeviceMemory<T> &x, int incx,
142                                    const T &beta, DeviceMemory<T> *y, int incy,
143                                    blas::ProfileResult *output_profile_result);
144 
145   // mutex that guards the cuBLAS handle for this device.
146   mutex mu_;
147 
148   // GpuExecutor which instantiated this CUDABlas.
149   // Immutable post-initialization.
150   GpuExecutor *parent_;
151 
152   // cuBLAS library handle on the device.
153   cublasHandle_t blas_ GUARDED_BY(mu_);
154 
155   SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
156 };
157 
158 }  // namespace gpu
159 }  // namespace stream_executor
160 
161 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
162