/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
// library capabilities, and is only included into ROCM implementation code --
// it will not introduce ROCm headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_

#include "absl/synchronization/mutex.h"
#include "rocm/include/rocblas.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {

class Stream;

namespace gpu {

// Type conversion helper that maps non-rocBLAS types to their rocBLAS
// equivalents: Eigen::half to rocblas_half, and std::complex<float> /
// std::complex<double> to the corresponding rocBLAS complex types.
template <typename T>
struct RocBlasTypeConversionHelper {
  using mapped_type = T;
};

template <>
struct RocBlasTypeConversionHelper<Eigen::half> {
  using mapped_type = rocblas_half;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<float>> {
  using mapped_type = rocblas_float_complex;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<double>> {
  using mapped_type = rocblas_double_complex;
};
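
// Illustrative sketch (not part of the interface above): implementation code
// that needs the rocBLAS type matching a framework type can query the helper,
// e.g.
//
//   using MappedT =
//       typename RocBlasTypeConversionHelper<Eigen::half>::mapped_type;
//   static_assert(std::is_same<MappedT, rocblas_half>::value,
//                 "Eigen::half maps to rocblas_half");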

// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;

class GpuExecutor;

// BLAS plugin for the ROCM platform via the rocBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the rocBLAS handle that this encapsulates is implicitly tied to
// the context (and, as a result, the device) that the parent GpuExecutor is
// tied to. This simply happens as an artifact of creating the rocBLAS handle
// when a ROCM context is active.
//
// Thread-safe post-initialization.
class ROCMBlas : public blas::BlasSupport {
 public:
  explicit ROCMBlas(GpuExecutor *parent);

  // Allocates a rocBLAS handle.
  bool Init();

  // Releases the rocBLAS handle, if present.
  ~ROCMBlas() override;

  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // rocBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into rocBLAS.
  bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A helper function that calls the real rocBLAS function together with
  // error handling.
  //
  // rocblas_func: rocBLAS function pointer.
  // stream: Stream to enqueue the BLAS operation onto.
  // pointer_mode_host: Indicates whether the pointer to a scalar value is
  //   from host (true) or device (false) memory.
  // err_on_failure: Whether to print an error if the rocBLAS function fails.
  // args: Arguments of the rocBLAS function.
  template <typename FuncT, typename... Args>
  bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
                          bool pointer_mode_host, bool err_on_failure,
                          Args... args);

  // Convenience functions that call DoBlasInternalImpl with different values
  // for err_on_failure.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
                      bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/true, args...);
  }

  // Same as above, but returns a port::Status instead of a bool.
  template <typename... Args>
  port::Status DoBlasInternalStatus(Args... args) {
    if (!DoBlasInternal(args...)) {
      return port::InternalError("Failed calling rocBLAS");
    }
    return port::Status::OK();
  }

  template <typename FuncT, typename... Args>
  bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
                               bool pointer_mode_host, Args... args) {
    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/false, args...);
  }
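
  // Illustrative sketch (hypothetical call site, not a declaration in this
  // header): in the implementation file, a BLAS routine is typically forwarded
  // through DoBlasInternal. Assuming a wrap::rocblas_sscal wrapper and a
  // GpuMemoryMutable() accessor in that file, a call might look roughly like:
  //
  //   bool ok = DoBlasInternal(wrap::rocblas_sscal, stream,
  //                            /*pointer_mode_host=*/true, elem_count, &alpha,
  //                            GpuMemoryMutable(x), incx);
  //
  // Everything after pointer_mode_host is forwarded as the rocBLAS function's
  // arguments (minus the handle, which is associated with the stream via
  // SetStream).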

  // A helper allocation function that converts a raw-pointer memory layout
  // into the strided flavor.
  template <typename T>
  port::Status AllocateStridedBuffer(
      const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
          &raw_ptrs,
      int batch_count, uint64_t batch_stride,
      ScratchAllocator *scratch_allocator, Stream *stream,
      std::unique_ptr<TemporaryDeviceMemory<
          typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
      DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
          *device_memory,
      bool copy_data, bool &reallocated);

  // A helper function to implement the DoBlasGemmBatched interfaces for
  // generic types.
  //
  // Note: This function is implemented using the gemm_strided_batched
  // interface, NOT the gemm_batched interface, because rocBLAS does not
  // support the latter. As a result, if the batch matrices passed in are not
  // already laid out in strided-batched form (i.e., matrix i of the batch
  // starts at a fixed offset of i * batch_stride elements from matrix 0), the
  // call may incur a non-trivial amount of memory allocation and copying. To
  // avoid this, prefer the DoBlasGemmStridedBatched interface whenever
  // possible.
  //
  // In most use cases the batch matrices are allocated in a strided manner,
  // which makes calling this interface equivalent to calling
  // DoBlasGemmStridedBatched. The only use case observed so far that violates
  // this assumption is when the batch matrices are created by broadcasting
  // from a smaller matrix; when that happens, this function uses the
  // AllocateStridedBuffer subroutine to reallocate the memory layout into
  // strided-batched form.
  template <typename T, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT rocblas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
      T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // Mutex that guards the rocBLAS handle for this device.
  absl::Mutex mu_;

  // GpuExecutor which instantiated this ROCMBlas.
  // Immutable post-initialization.
  GpuExecutor *parent_;

  // rocBLAS library handle on the device.
  rocblas_handle blas_ TF_GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_