1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // ROCM-specific support for BLAS functionality -- this wraps the rocBLAS 17 // library capabilities, and is only included into ROCM implementation code -- 18 // it will not introduce rocm headers into other code. 19 20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 22 23 #include "absl/synchronization/mutex.h" 24 #include "tensorflow/core/platform/thread_annotations.h" 25 #include "tensorflow/stream_executor/blas.h" 26 #include "tensorflow/stream_executor/platform/port.h" 27 #include "tensorflow/stream_executor/plugin_registry.h" 28 #include "tensorflow/stream_executor/temporary_device_memory.h" 29 30 namespace stream_executor { 31 32 class Stream; 33 34 namespace gpu { 35 36 // Type conversion helper that helps to map non-rocblas types to rocblas types 37 // Right now, it only converts the Eigen::half type to rocblas_half type 38 template <typename T> 39 struct RocBlasTypeConversionHelper { 40 using mapped_type = T; 41 }; 42 43 template <> 44 struct RocBlasTypeConversionHelper<Eigen::half> { 45 using mapped_type = rocblas_half; 46 }; 47 48 template <> 49 struct RocBlasTypeConversionHelper<std::complex<float>> { 50 using mapped_type = rocblas_float_complex; 51 }; 52 53 template <> 54 struct RocBlasTypeConversionHelper<std::complex<double>> { 55 using mapped_type = rocblas_double_complex; 56 }; 57 58 // Opaque and unique identifier for the rocBLAS plugin. 59 extern const PluginId kRocBlasPlugin; 60 61 class GpuExecutor; 62 63 // BLAS plugin for ROCM platform via rocBLAS library. 64 // 65 // This satisfies the platform-agnostic BlasSupport interface. 66 // 67 // Note that the rocBLAS handle that this encapsulates is implicitly tied to the 68 // context (and, as a result, the device) that the parent GpuExecutor is tied 69 // to. This simply happens as an artifact of creating the rocBLAS handle when a 70 // ROCM context is active. 71 // 72 // Thread-safe post-initialization. 73 class ROCMBlas : public blas::BlasSupport { 74 public: 75 explicit ROCMBlas(GpuExecutor *parent); 76 77 // Allocates a rocBLAS handle. 78 bool Init(); 79 80 // Releases the rocBLAS handle, if present. 81 ~ROCMBlas() override; 82 83 TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES 84 85 private: 86 // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream. 87 // 88 // rocBLAS is stateful, and only be associated with one stream (in order to 89 // enqueue dispatch) at a given time. As a result, this generally must be 90 // invoked before calling into rocBLAS. 91 bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 92 93 // A helper function that calls the real rocBLAS function together with error 94 // handling. 95 // 96 // rocblas_func: rocBLAS function pointer. 97 // rocblas_name: rocBLAS function name. 98 // stream: Stream to enqueue the BLAS operation onto. 99 // pointer_mode_host: Indicate if the pointer to a scalar value is from host 100 // (true) or device (false). 101 // err_on_failure: Whether to print an error if the rocBLAS function 102 // fails. args: Arguments of rocBLAS function. 103 template <typename FuncT, typename... Args> 104 bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, 105 bool pointer_mode_host, bool err_on_failure, 106 Args... args); 107 108 // Convenience functions that call DoBlasInternalImpl with different values 109 // for err_on_failure. 110 template <typename FuncT, typename... Args> 111 bool DoBlasInternal(FuncT rocblas_func, Stream *stream, 112 bool pointer_mode_host, Args... args) { 113 return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, 114 /*err_on_failure=*/true, args...); 115 } 116 template <typename FuncT, typename... Args> 117 bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, 118 bool pointer_mode_host, Args... args) { 119 return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, 120 /*err_on_failure=*/false, args...); 121 } 122 123 // A helper allocation function to convert raw pointers memory layout to 124 // strided flavor 125 template <typename T> 126 port::Status AllocateStridedBuffer( 127 const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *> 128 &raw_ptrs, 129 int batch_count, uint64_t batch_stride, 130 ScratchAllocator *scratch_allocator, Stream *stream, 131 std::unique_ptr<TemporaryDeviceMemory< 132 typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory, 133 DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> 134 *device_memory, 135 bool copy_data, bool &reallocated); 136 137 // A helper function to implement DoBlasGemmBatched interfaces for generic 138 // types. 139 // 140 // Note: This function is implemented using gemm_strided_batched interface, 141 // NOT gemm_batched interface, because rocblas do not support it. As a 142 // result, if the passed in batch matrix are not allocated in strided batched 143 // format, it might end up in non-trivial amount of memory allocation and 144 // copy. To avoid this, always prioritize to use DoBlasGemmStridedBatched 145 // interface. 146 // 147 // In most use cases, batch matrix do get allocated in strided manner, making 148 // calling this interface equivalent with DoBlasGemmStridedBatched. The only 149 // use case we see so far that violates this observation is when batch 150 // matrix is created by broadcasting from a smaller matrix. When it happens, 151 // It will take advantage of the AllocateStridedBuffer subroutine to 152 // reallocate the memory layout to be strided batched. 153 template <typename T, typename FuncT> 154 port::Status DoBlasGemmBatchedInternal( 155 FuncT rocblas_func, Stream *stream, blas::Transpose transa, 156 blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, 157 const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda, 158 const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb, 159 T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers, 160 int ldc, int batch_count, ScratchAllocator *scratch_allocator); 161 162 // Helper function for implementing DoBlasGemmWithAlgorithm. 163 // 164 // We take alpha and beta by const reference because T might be Eigen::half, 165 // and we want to avoid pulling in a dependency on Eigen. When we pass the 166 // references to rocBLAS, we essentially reinterpret_cast to __half, which is 167 // safe because Eigen::half inherits from __half. 168 template <typename InT, typename OutT, typename CompT> 169 bool DoBlasGemmWithAlgorithmImpl( 170 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 171 uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, 172 int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta, 173 DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type, 174 blas::AlgorithmType algorithm, 175 blas::ProfileResult *output_profile_result); 176 177 // Helper function for implementing DoBlasGemmWithProfiling. 178 template <typename T, typename ParamType> 179 bool DoBlasGemmWithProfilingImpl( 180 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 181 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a, 182 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta, 183 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result); 184 185 // Helper function for implementing DoBlasGemvWithProfiling. 186 template <typename T> 187 bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans, 188 uint64 m, uint64 n, const T &alpha, 189 const DeviceMemory<T> &a, int lda, 190 const DeviceMemory<T> &x, int incx, 191 const T &beta, DeviceMemory<T> *y, int incy, 192 blas::ProfileResult *output_profile_result); 193 194 // mutex that guards the rocBLAS handle for this device. 195 absl::Mutex mu_; 196 197 // GpuExecutor which instantiated this ROCMBlas. 198 // Immutable post-initialization. 199 GpuExecutor *parent_; 200 201 // rocBLAS library handle on the device. 202 rocblas_handle blas_ TF_GUARDED_BY(mu_); 203 204 SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas); 205 }; 206 207 } // namespace gpu 208 } // namespace stream_executor 209 210 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 211