• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
17 // library capabilities, and is only included into ROCM implementation code --
18 // it will not introduce rocm headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
22 
23 #include "absl/synchronization/mutex.h"
24 #include "rocm/include/rocblas.h"
25 #include "tensorflow/core/platform/thread_annotations.h"
26 #include "tensorflow/stream_executor/blas.h"
27 #include "tensorflow/stream_executor/platform/port.h"
28 #include "tensorflow/stream_executor/plugin_registry.h"
29 #include "tensorflow/stream_executor/temporary_device_memory.h"
30 
31 namespace stream_executor {
32 
33 class Stream;
34 
35 namespace gpu {
36 
// Maps an element type used by the caller onto the element type used when
// calling into rocBLAS. The primary template is the identity mapping; the
// explicit specializations in this header override it for the types whose
// rocBLAS counterpart has a different name (Eigen::half and the two
// std::complex types).
template <typename ElemT>
struct RocBlasTypeConversionHelper {
  using mapped_type = ElemT;
};
43 
44 template <>
45 struct RocBlasTypeConversionHelper<Eigen::half> {
46   using mapped_type = rocblas_half;
47 };
48 
49 template <>
50 struct RocBlasTypeConversionHelper<std::complex<float>> {
51   using mapped_type = rocblas_float_complex;
52 };
53 
54 template <>
55 struct RocBlasTypeConversionHelper<std::complex<double>> {
56   using mapped_type = rocblas_double_complex;
57 };
58 
59 // Opaque and unique identifier for the rocBLAS plugin.
60 extern const PluginId kRocBlasPlugin;
61 
62 class GpuExecutor;
63 
64 // BLAS plugin for ROCM platform via rocBLAS library.
65 //
66 // This satisfies the platform-agnostic BlasSupport interface.
67 //
68 // Note that the rocBLAS handle that this encapsulates is implicitly tied to the
69 // context (and, as a result, the device) that the parent GpuExecutor is tied
70 // to. This simply happens as an artifact of creating the rocBLAS handle when a
71 // ROCM context is active.
72 //
73 // Thread-safe post-initialization.
74 class ROCMBlas : public blas::BlasSupport {
75  public:
76   explicit ROCMBlas(GpuExecutor *parent);
77 
78   // Allocates a rocBLAS handle.
79   bool Init();
80 
81   // Releases the rocBLAS handle, if present.
82   ~ROCMBlas() override;
83 
84   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
85 
86  private:
87   // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
88   //
89   // rocBLAS is stateful, and only be associated with one stream (in order to
90   // enqueue dispatch) at a given time. As a result, this generally must be
91   // invoked before calling into rocBLAS.
92   bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
93 
94   // A helper function that calls the real rocBLAS function together with error
95   // handling.
96   //
97   // rocblas_func:       rocBLAS function pointer.
98   // rocblas_name:       rocBLAS function name.
99   // stream:             Stream to enqueue the BLAS operation onto.
100   // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
101   //                     (true) or device (false).
102   // err_on_failure:     Whether to print an error if the rocBLAS function
103   // fails. args:               Arguments of rocBLAS function.
104   template <typename FuncT, typename... Args>
105   bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
106                           bool pointer_mode_host, bool err_on_failure,
107                           Args... args);
108 
109   // Convenience functions that call DoBlasInternalImpl with different values
110   // for err_on_failure.
111   template <typename FuncT, typename... Args>
112   bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
113                       bool pointer_mode_host, Args... args) {
114     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
115                               /*err_on_failure=*/true, args...);
116   }
117 
118   // Same as above, but returns Status.
119   template <typename... Args>
120   port::Status DoBlasInternalStatus(Args... args) {
121     if (!DoBlasInternal(args...)) {
122       return port::InternalError("Failed calling rocBLAS");
123     }
124     return port::Status::OK();
125   }
126 
127   template <typename FuncT, typename... Args>
128   bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
129                                bool pointer_mode_host, Args... args) {
130     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
131                               /*err_on_failure=*/false, args...);
132   }
133 
134   // A helper allocation function to convert raw pointers memory layout to
135   // strided flavor
136   template <typename T>
137   port::Status AllocateStridedBuffer(
138       const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
139           &raw_ptrs,
140       int batch_count, uint64_t batch_stride,
141       ScratchAllocator *scratch_allocator, Stream *stream,
142       std::unique_ptr<TemporaryDeviceMemory<
143           typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
144       DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
145           *device_memory,
146       bool copy_data, bool &reallocated);
147 
148   // A helper function to implement DoBlasGemmBatched interfaces for generic
149   // types.
150   //
151   // Note: This function is implemented using gemm_strided_batched interface,
152   // NOT gemm_batched interface, because rocblas do not support it. As a
153   // result, if the passed in batch matrix are not allocated in strided batched
154   // format, it might end up in non-trivial amount of memory allocation and
155   // copy. To avoid this, always prioritize to use DoBlasGemmStridedBatched
156   // interface.
157   //
158   // In most use cases, batch matrix do get allocated in strided manner, making
159   // calling this interface equivalent with DoBlasGemmStridedBatched. The only
160   // use case we see so far that violates this observation is when batch
161   // matrix is created by broadcasting from a smaller matrix. When it happens,
162   // It will take advantage of the AllocateStridedBuffer subroutine to
163   // reallocate the memory layout to be strided batched.
164   template <typename T, typename FuncT>
165   port::Status DoBlasGemmBatchedInternal(
166       FuncT rocblas_func, Stream *stream, blas::Transpose transa,
167       blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
168       const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
169       const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
170       T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
171       int ldc, int batch_count, ScratchAllocator *scratch_allocator);
172 
173   // Helper function for implementing DoBlasGemmWithProfiling.
174   template <typename T, typename ParamType>
175   bool DoBlasGemmWithProfilingImpl(
176       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
177       uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
178       int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
179       DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
180 
181   // Helper function for implementing DoBlasGemvWithProfiling.
182   template <typename T>
183   bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
184                                    uint64 m, uint64 n, const T &alpha,
185                                    const DeviceMemory<T> &a, int lda,
186                                    const DeviceMemory<T> &x, int incx,
187                                    const T &beta, DeviceMemory<T> *y, int incy,
188                                    blas::ProfileResult *output_profile_result);
189 
190   // mutex that guards the rocBLAS handle for this device.
191   absl::Mutex mu_;
192 
193   // GpuExecutor which instantiated this ROCMBlas.
194   // Immutable post-initialization.
195   GpuExecutor *parent_;
196 
197   // rocBLAS library handle on the device.
198   rocblas_handle blas_ TF_GUARDED_BY(mu_);
199 
200   SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
201 };
202 
203 }  // namespace gpu
204 }  // namespace stream_executor
205 
206 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
207