• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================
15 */
16 
17 #ifndef TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_
18 #define TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_
19 
// This header declares the class ROCmSolver, which contains wrappers of linear
// algebra routines from the ROCm rocBLAS library for use in TensorFlow
// kernels.
23 
24 #if TENSORFLOW_USE_ROCM
25 
#include <complex>
#include <functional>
#include <vector>

#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/rocblas.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/stream_executor/blas.h"
36 
37 namespace tensorflow {
38 
39 // Type traits to get ROCm complex types from std::complex<T>.
40 template <typename T>
41 struct ROCmComplexT {
42   typedef T type;
43 };
44 template <>
45 struct ROCmComplexT<std::complex<float>> {
46   typedef hipComplex type;
47 };
48 template <>
49 struct ROCmComplexT<std::complex<double>> {
50   typedef hipDoubleComplex type;
51 };
52 // Converts pointers of std::complex<> to pointers of
53 // cuComplex/cuDoubleComplex. No type conversion for non-complex types.
54 template <typename T>
55 inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) {
56   return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p);
57 }
58 template <typename T>
59 inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) {
60   return reinterpret_cast<typename ROCmComplexT<T>::type*>(p);
61 }
62 
63 template <typename Scalar>
64 class ScratchSpace;
65 
66 class ROCmSolver {
67  public:
68   // This object stores a pointer to context, which must outlive it.
69   explicit ROCmSolver(OpKernelContext* context);
70   virtual ~ROCmSolver();
71 
72   // Allocates a temporary tensor that will live for the duration of the
73   // ROCmSolver object.
74   Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
75                                 Tensor* scoped_tensor);
76   Status forward_input_or_allocate_scoped_tensor(
77       gtl::ArraySlice<int> candidate_input_indices, DataType type,
78       const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);
79 
80   OpKernelContext* context() { return context_; }
81 
82   template <typename Scalar>
83   Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans,
84               rocblas_diagonal diag, int m, int n, const Scalar* alpha,
85               const Scalar* A, int lda, Scalar* B, int ldb);
86 
87  private:
88   OpKernelContext* context_;  // not owned.
89   hipStream_t hip_stream_;
90   rocblas_handle rocm_blas_handle_;
91   std::vector<TensorReference> scratch_tensor_refs_;
92 
93   TF_DISALLOW_COPY_AND_ASSIGN(ROCmSolver);
94 };
95 
96 // Helper class to allocate scratch memory and keep track of debug info.
97 // Mostly a thin wrapper around Tensor & allocate_temp.
98 template <typename Scalar>
99 class ScratchSpace {
100  public:
101   ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
102       : ScratchSpace(context, TensorShape({size}), "", on_host) {}
103 
104   ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
105                bool on_host)
106       : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}
107 
108   ScratchSpace(OpKernelContext* context, const TensorShape& shape,
109                const string& debug_info, bool on_host)
110       : context_(context), debug_info_(debug_info), on_host_(on_host) {
111     AllocatorAttributes alloc_attr;
112     if (on_host) {
113       // Allocate pinned memory on the host to avoid unnecessary
114       // synchronization.
115       alloc_attr.set_on_host(true);
116       alloc_attr.set_gpu_compatible(true);
117     }
118     TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
119                                        &scratch_tensor_, alloc_attr));
120   }
121 
122   virtual ~ScratchSpace() {}
123 
124   Scalar* mutable_data() {
125     return scratch_tensor_.template flat<Scalar>().data();
126   }
127   const Scalar* data() const {
128     return scratch_tensor_.template flat<Scalar>().data();
129   }
130   Scalar& operator()(int64 i) {
131     return scratch_tensor_.template flat<Scalar>()(i);
132   }
133   const Scalar& operator()(int64 i) const {
134     return scratch_tensor_.template flat<Scalar>()(i);
135   }
136   int64 bytes() const { return scratch_tensor_.TotalBytes(); }
137   int64 size() const { return scratch_tensor_.NumElements(); }
138   const string& debug_info() const { return debug_info_; }
139 
140   Tensor& tensor() { return scratch_tensor_; }
141   const Tensor& tensor() const { return scratch_tensor_; }
142 
143   // Returns true if this ScratchSpace is in host memory.
144   bool on_host() const { return on_host_; }
145 
146  protected:
147   OpKernelContext* context() const { return context_; }
148 
149  private:
150   OpKernelContext* context_;  // not owned
151   const string debug_info_;
152   const bool on_host_;
153   Tensor scratch_tensor_;
154 };
155 
156 }  // namespace tensorflow
157 
158 #endif  // TENSORFLOW_USE_ROCM
159 
160 #endif  // TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_
161