1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ============================================================================== 15 */ 16 17 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_ROCM_SOLVERS_H_ 18 #define TENSORFLOW_CORE_KERNELS_LINALG_ROCM_SOLVERS_H_ 19 20 // This header declares the class ROCmSolver, which contains wrappers of linear 21 // algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow 22 // kernels. 23 24 #if TENSORFLOW_USE_ROCM 25 26 #include <functional> 27 #include <vector> 28 29 #include "rocm/include/hip/hip_complex.h" 30 #include "rocm/include/rocblas.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/tensor.h" 33 #include "tensorflow/core/framework/tensor_reference.h" 34 #include "tensorflow/core/lib/core/status.h" 35 #include "tensorflow/core/platform/stream_executor.h" 36 #include "tensorflow/stream_executor/blas.h" 37 38 namespace tensorflow { 39 40 // Type traits to get ROCm complex types from std::complex<T>. 41 template <typename T> 42 struct ROCmComplexT { 43 typedef T type; 44 }; 45 template <> 46 struct ROCmComplexT<std::complex<float>> { 47 typedef hipComplex type; 48 }; 49 template <> 50 struct ROCmComplexT<std::complex<double>> { 51 typedef hipDoubleComplex type; 52 }; 53 // Converts pointers of std::complex<> to pointers of 54 // cuComplex/cuDoubleComplex. No type conversion for non-complex types. 55 template <typename T> 56 inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) { 57 return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p); 58 } 59 template <typename T> 60 inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) { 61 return reinterpret_cast<typename ROCmComplexT<T>::type*>(p); 62 } 63 64 template <typename Scalar> 65 class ScratchSpace; 66 67 class ROCmSolver { 68 public: 69 // This object stores a pointer to context, which must outlive it. 70 explicit ROCmSolver(OpKernelContext* context); 71 virtual ~ROCmSolver(); 72 73 // Allocates a temporary tensor that will live for the duration of the 74 // ROCmSolver object. 75 Status allocate_scoped_tensor(DataType type, const TensorShape& shape, 76 Tensor* scoped_tensor); 77 Status forward_input_or_allocate_scoped_tensor( 78 gtl::ArraySlice<int> candidate_input_indices, DataType type, 79 const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor); 80 81 OpKernelContext* context() { return context_; } 82 83 template <typename Scalar> 84 Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans, 85 rocblas_diagonal diag, int m, int n, const Scalar* alpha, 86 const Scalar* A, int lda, Scalar* B, int ldb); 87 88 private: 89 OpKernelContext* context_; // not owned. 90 hipStream_t hip_stream_; 91 rocblas_handle rocm_blas_handle_; 92 std::vector<TensorReference> scratch_tensor_refs_; 93 94 TF_DISALLOW_COPY_AND_ASSIGN(ROCmSolver); 95 }; 96 97 // Helper class to allocate scratch memory and keep track of debug info. 98 // Mostly a thin wrapper around Tensor & allocate_temp. 99 template <typename Scalar> 100 class ScratchSpace { 101 public: 102 ScratchSpace(OpKernelContext* context, int64 size, bool on_host) 103 : ScratchSpace(context, TensorShape({size}), "", on_host) {} 104 105 ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info, 106 bool on_host) 107 : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {} 108 109 ScratchSpace(OpKernelContext* context, const TensorShape& shape, 110 const string& debug_info, bool on_host) 111 : context_(context), debug_info_(debug_info), on_host_(on_host) { 112 AllocatorAttributes alloc_attr; 113 if (on_host) { 114 // Allocate pinned memory on the host to avoid unnecessary 115 // synchronization. 116 alloc_attr.set_on_host(true); 117 alloc_attr.set_gpu_compatible(true); 118 } 119 TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape, 120 &scratch_tensor_, alloc_attr)); 121 } 122 123 virtual ~ScratchSpace() {} 124 125 Scalar* mutable_data() { 126 return scratch_tensor_.template flat<Scalar>().data(); 127 } 128 const Scalar* data() const { 129 return scratch_tensor_.template flat<Scalar>().data(); 130 } 131 Scalar& operator()(int64 i) { 132 return scratch_tensor_.template flat<Scalar>()(i); 133 } 134 const Scalar& operator()(int64 i) const { 135 return scratch_tensor_.template flat<Scalar>()(i); 136 } 137 int64 bytes() const { return scratch_tensor_.TotalBytes(); } 138 int64 size() const { return scratch_tensor_.NumElements(); } 139 const string& debug_info() const { return debug_info_; } 140 141 Tensor& tensor() { return scratch_tensor_; } 142 const Tensor& tensor() const { return scratch_tensor_; } 143 144 // Returns true if this ScratchSpace is in host memory. 145 bool on_host() const { return on_host_; } 146 147 protected: 148 OpKernelContext* context() const { return context_; } 149 150 private: 151 OpKernelContext* context_; // not owned 152 const string debug_info_; 153 const bool on_host_; 154 Tensor scratch_tensor_; 155 }; 156 157 } // namespace tensorflow 158 159 #endif // TENSORFLOW_USE_ROCM 160 161 #endif // TENSORFLOW_CORE_KERNELS_LINALG_ROCM_SOLVERS_H_ 162