/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
// library capabilities, and is only included into ROCM implementation code --
// it will not introduce rocm headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_

#include "rocm/include/rocblas.h"  // rocblas_handle, rocblas_half and the
                                   // rocblas_*_complex types used below.
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {

class Stream;

namespace gpu {

// Type conversion helper that maps non-rocBLAS types to their rocBLAS
// equivalents: Eigen::half maps to rocblas_half, and std::complex<float> /
// std::complex<double> map to the rocBLAS complex types. All other types map
// to themselves via this primary template.
template <typename T>
struct RocBlasTypeConversionHelper {
  using mapped_type = T;
};

template <>
struct RocBlasTypeConversionHelper<Eigen::half> {
  using mapped_type = rocblas_half;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<float>> {
  using mapped_type = rocblas_float_complex;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<double>> {
  using mapped_type = rocblas_double_complex;
};
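
// Example (a minimal sketch): the helper derives the element type handed to
// rocBLAS from the StreamExecutor element type, e.g.
//
//   using MappedT =
//       typename RocBlasTypeConversionHelper<Eigen::half>::mapped_type;
//   static_assert(std::is_same<MappedT, rocblas_half>::value, "");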

// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;

class GpuExecutor;

// BLAS plugin for ROCM platform via rocBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the rocBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the rocBLAS handle when a
// ROCM context is active.
//
// Thread-safe post-initialization.
class ROCMBlas : public blas::BlasSupport {
 public:
  explicit ROCMBlas(GpuExecutor *parent);

  // Allocates a rocBLAS handle.
  bool Init();

  // Releases the rocBLAS handle, if present.
  ~ROCMBlas() override;

  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // rocBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into rocBLAS.
  bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A helper function that calls the real rocBLAS function together with error
  // handling.
  //
  // rocblas_func:       rocBLAS function pointer.
  // stream:             Stream to enqueue the BLAS operation onto.
  // pointer_mode_host:  Indicates whether the pointer to a scalar value is
  //                     from host (true) or device (false).
  // err_on_failure:     Whether to print an error if the rocBLAS function
  //                     fails.
  // args:               Arguments of the rocBLAS function.
  template <typename FuncT, typename... Args>
  bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
                          bool pointer_mode_host, bool err_on_failure,
                          Args... args);

  // Convenience functions that call DoBlasInternalImpl with different values
  // for err_on_failure.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
                      bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/true, args...);
  }
  template <typename FuncT, typename... Args>
  bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
                               bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/false, args...);
  }
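
  // Illustrative call pattern for these wrappers (a sketch only; the
  // `wrap::rocblas_saxpy` symbol and the GpuMemory/GpuMemoryMutable helpers
  // live in the implementation file and are assumptions here, not part of
  // this header):
  //
  //   bool ok = DoBlasInternal(
  //       wrap::rocblas_saxpy, stream, /*pointer_mode_host=*/true, elem_count,
  //       &alpha, GpuMemory(x), incx, GpuMemoryMutable(y), incy);
  //
  // DoBlasInternalFailureOK behaves identically except that a failing rocBLAS
  // call is not logged as an error.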

  // A helper allocation function that converts a set of raw device pointers
  // into the strided-batched memory layout, reallocating (and optionally
  // copying) into temporary device memory when the input is not already
  // strided.
  template <typename T>
  port::Status AllocateStridedBuffer(
      const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
          &raw_ptrs,
      int batch_count, uint64_t batch_stride,
      ScratchAllocator *scratch_allocator, Stream *stream,
      std::unique_ptr<TemporaryDeviceMemory<
          typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
      DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
          *device_memory,
      bool copy_data, bool &reallocated);

  // A helper function to implement the DoBlasGemmBatched interfaces for
  // generic types.
  //
  // Note: This function is implemented using the gemm_strided_batched
  // interface, NOT the gemm_batched interface, because rocBLAS does not
  // support the latter. As a result, if the batch matrices passed in are not
  // allocated in strided-batched format, this may incur a non-trivial amount
  // of memory allocation and copying. To avoid this, prefer the
  // DoBlasGemmStridedBatched interface whenever possible.
  //
  // In most use cases, batch matrices are allocated in a strided manner,
  // making a call to this interface equivalent to DoBlasGemmStridedBatched.
  // The only use case we have seen so far that violates this observation is
  // when the batch matrices are created by broadcasting from a smaller
  // matrix. When that happens, this function uses the AllocateStridedBuffer
  // subroutine to reallocate the memory layout into strided-batched form.
  template <typename T, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT rocblas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
      T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
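
  // For reference, "strided batched" means batch i of a (column-major) matrix
  // A begins at a + i * batch_stride_a, so contiguously packed batches
  // satisfy, e.g., batch_stride_a == lda * k for a non-transposed m x k A. A
  // minimal sketch of the layout check this implies (names are illustrative
  // only):
  //
  //   bool is_strided = true;
  //   for (int i = 1; i < batch_count; ++i) {
  //     is_strided &= (a_raw_ptrs[i] == a_raw_ptrs[0] + i * batch_stride_a);
  //   }
  //
  // When a batch is not laid out this way, AllocateStridedBuffer repacks it
  // into a temporary strided allocation before gemm_strided_batched is
  // invoked.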

  // Helper function for implementing DoBlasGemmWithAlgorithm.
  //
  // We take alpha and beta by const reference because T might be Eigen::half,
  // and we want to avoid pulling in a dependency on Eigen.  When we pass the
  // references to rocBLAS, we essentially reinterpret_cast to __half, which is
  // safe because Eigen::half inherits from __half.
  template <typename InT, typename OutT, typename CompT>
  bool DoBlasGemmWithAlgorithmImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);
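
  // Sketch of the pass-by-reference-then-reinterpret pattern described above
  // (illustrative only; the concrete target type is whatever
  // RocBlasTypeConversionHelper maps the element type to):
  //
  //   const Eigen::half alpha(1.0f);
  //   const rocblas_half *alpha_rb =
  //       reinterpret_cast<const rocblas_half *>(&alpha);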

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // Mutex that guards the rocBLAS handle for this device.
  absl::Mutex mu_;

  // GpuExecutor which instantiated this ROCMBlas.
  // Immutable post-initialization.
  GpuExecutor *parent_;

  // rocBLAS library handle on the device.
  rocblas_handle blas_ TF_GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
};
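
// Clients normally reach this plugin through Stream's BLAS entry points
// rather than by constructing ROCMBlas directly. A minimal usage sketch
// (assuming `stream` was created on a ROCm StreamExecutor and a, b, c are
// column-major DeviceMemory<float> buffers; names here are illustrative
// only):
//
//   stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
//                        blas::Transpose::kNoTranspose, m, n, k,
//                        /*alpha=*/1.0f, a, lda, b, ldb,
//                        /*beta=*/0.0f, &c, ldc);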

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_