1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ============================================================================== 15 */ 16 17 #ifndef TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_ 18 #define TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_ 19 20 // This header declares the class ROCmSolver, which contains wrappers of linear 21 // algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow 22 // kernels. 23 24 #if TENSORFLOW_USE_ROCM 25 26 #include <functional> 27 #include <vector> 28 29 #include "rocm/include/hip/hip_complex.h" 30 #include "rocm/include/rocblas.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/tensor.h" 33 #include "tensorflow/core/lib/core/status.h" 34 #include "tensorflow/core/platform/stream_executor.h" 35 #include "tensorflow/stream_executor/blas.h" 36 37 namespace tensorflow { 38 39 // Type traits to get ROCm complex types from std::complex<T>. 40 template <typename T> 41 struct ROCmComplexT { 42 typedef T type; 43 }; 44 template <> 45 struct ROCmComplexT<std::complex<float>> { 46 typedef hipComplex type; 47 }; 48 template <> 49 struct ROCmComplexT<std::complex<double>> { 50 typedef hipDoubleComplex type; 51 }; 52 // Converts pointers of std::complex<> to pointers of 53 // cuComplex/cuDoubleComplex. No type conversion for non-complex types. 54 template <typename T> 55 inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) { 56 return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p); 57 } 58 template <typename T> 59 inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) { 60 return reinterpret_cast<typename ROCmComplexT<T>::type*>(p); 61 } 62 63 template <typename Scalar> 64 class ScratchSpace; 65 66 class ROCmSolver { 67 public: 68 // This object stores a pointer to context, which must outlive it. 69 explicit ROCmSolver(OpKernelContext* context); 70 virtual ~ROCmSolver(); 71 72 // Allocates a temporary tensor that will live for the duration of the 73 // ROCmSolver object. 74 Status allocate_scoped_tensor(DataType type, const TensorShape& shape, 75 Tensor* scoped_tensor); 76 Status forward_input_or_allocate_scoped_tensor( 77 gtl::ArraySlice<int> candidate_input_indices, DataType type, 78 const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor); 79 80 OpKernelContext* context() { return context_; } 81 82 template <typename Scalar> 83 Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans, 84 rocblas_diagonal diag, int m, int n, const Scalar* alpha, 85 const Scalar* A, int lda, Scalar* B, int ldb); 86 87 private: 88 OpKernelContext* context_; // not owned. 89 hipStream_t hip_stream_; 90 rocblas_handle rocm_blas_handle_; 91 std::vector<TensorReference> scratch_tensor_refs_; 92 93 TF_DISALLOW_COPY_AND_ASSIGN(ROCmSolver); 94 }; 95 96 // Helper class to allocate scratch memory and keep track of debug info. 97 // Mostly a thin wrapper around Tensor & allocate_temp. 98 template <typename Scalar> 99 class ScratchSpace { 100 public: 101 ScratchSpace(OpKernelContext* context, int64 size, bool on_host) 102 : ScratchSpace(context, TensorShape({size}), "", on_host) {} 103 104 ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info, 105 bool on_host) 106 : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {} 107 108 ScratchSpace(OpKernelContext* context, const TensorShape& shape, 109 const string& debug_info, bool on_host) 110 : context_(context), debug_info_(debug_info), on_host_(on_host) { 111 AllocatorAttributes alloc_attr; 112 if (on_host) { 113 // Allocate pinned memory on the host to avoid unnecessary 114 // synchronization. 115 alloc_attr.set_on_host(true); 116 alloc_attr.set_gpu_compatible(true); 117 } 118 TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape, 119 &scratch_tensor_, alloc_attr)); 120 } 121 122 virtual ~ScratchSpace() {} 123 124 Scalar* mutable_data() { 125 return scratch_tensor_.template flat<Scalar>().data(); 126 } 127 const Scalar* data() const { 128 return scratch_tensor_.template flat<Scalar>().data(); 129 } 130 Scalar& operator()(int64 i) { 131 return scratch_tensor_.template flat<Scalar>()(i); 132 } 133 const Scalar& operator()(int64 i) const { 134 return scratch_tensor_.template flat<Scalar>()(i); 135 } 136 int64 bytes() const { return scratch_tensor_.TotalBytes(); } 137 int64 size() const { return scratch_tensor_.NumElements(); } 138 const string& debug_info() const { return debug_info_; } 139 140 Tensor& tensor() { return scratch_tensor_; } 141 const Tensor& tensor() const { return scratch_tensor_; } 142 143 // Returns true if this ScratchSpace is in host memory. 144 bool on_host() const { return on_host_; } 145 146 protected: 147 OpKernelContext* context() const { return context_; } 148 149 private: 150 OpKernelContext* context_; // not owned 151 const string debug_info_; 152 const bool on_host_; 153 Tensor scratch_tensor_; 154 }; 155 156 } // namespace tensorflow 157 158 #endif // TENSORFLOW_USE_ROCM 159 160 #endif // TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_ 161