• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================
15 */
16 
17 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_ROCM_SOLVERS_H_
18 #define TENSORFLOW_CORE_KERNELS_LINALG_ROCM_SOLVERS_H_
19 
// This header declares the class ROCmSolver, which contains wrappers of linear
// algebra routines in the rocBLAS library for use in TensorFlow kernels, and
// the helper class ScratchSpace used to allocate scratch memory.
23 
24 #if TENSORFLOW_USE_ROCM
25 
#include <complex>
#include <functional>
#include <vector>

#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/rocblas.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/stream_executor/blas.h"
37 
38 namespace tensorflow {
39 
40 // Type traits to get ROCm complex types from std::complex<T>.
41 template <typename T>
42 struct ROCmComplexT {
43   typedef T type;
44 };
45 template <>
46 struct ROCmComplexT<std::complex<float>> {
47   typedef hipComplex type;
48 };
49 template <>
50 struct ROCmComplexT<std::complex<double>> {
51   typedef hipDoubleComplex type;
52 };
53 // Converts pointers of std::complex<> to pointers of
54 // cuComplex/cuDoubleComplex. No type conversion for non-complex types.
55 template <typename T>
56 inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) {
57   return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p);
58 }
59 template <typename T>
60 inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) {
61   return reinterpret_cast<typename ROCmComplexT<T>::type*>(p);
62 }
63 
64 template <typename Scalar>
65 class ScratchSpace;
66 
67 class ROCmSolver {
68  public:
69   // This object stores a pointer to context, which must outlive it.
70   explicit ROCmSolver(OpKernelContext* context);
71   virtual ~ROCmSolver();
72 
73   // Allocates a temporary tensor that will live for the duration of the
74   // ROCmSolver object.
75   Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
76                                 Tensor* scoped_tensor);
77   Status forward_input_or_allocate_scoped_tensor(
78       gtl::ArraySlice<int> candidate_input_indices, DataType type,
79       const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);
80 
81   OpKernelContext* context() { return context_; }
82 
83   template <typename Scalar>
84   Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans,
85               rocblas_diagonal diag, int m, int n, const Scalar* alpha,
86               const Scalar* A, int lda, Scalar* B, int ldb);
87 
88  private:
89   OpKernelContext* context_;  // not owned.
90   hipStream_t hip_stream_;
91   rocblas_handle rocm_blas_handle_;
92   std::vector<TensorReference> scratch_tensor_refs_;
93 
94   TF_DISALLOW_COPY_AND_ASSIGN(ROCmSolver);
95 };
96 
97 // Helper class to allocate scratch memory and keep track of debug info.
98 // Mostly a thin wrapper around Tensor & allocate_temp.
99 template <typename Scalar>
100 class ScratchSpace {
101  public:
102   ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
103       : ScratchSpace(context, TensorShape({size}), "", on_host) {}
104 
105   ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
106                bool on_host)
107       : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}
108 
109   ScratchSpace(OpKernelContext* context, const TensorShape& shape,
110                const string& debug_info, bool on_host)
111       : context_(context), debug_info_(debug_info), on_host_(on_host) {
112     AllocatorAttributes alloc_attr;
113     if (on_host) {
114       // Allocate pinned memory on the host to avoid unnecessary
115       // synchronization.
116       alloc_attr.set_on_host(true);
117       alloc_attr.set_gpu_compatible(true);
118     }
119     TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
120                                        &scratch_tensor_, alloc_attr));
121   }
122 
123   virtual ~ScratchSpace() {}
124 
125   Scalar* mutable_data() {
126     return scratch_tensor_.template flat<Scalar>().data();
127   }
128   const Scalar* data() const {
129     return scratch_tensor_.template flat<Scalar>().data();
130   }
131   Scalar& operator()(int64 i) {
132     return scratch_tensor_.template flat<Scalar>()(i);
133   }
134   const Scalar& operator()(int64 i) const {
135     return scratch_tensor_.template flat<Scalar>()(i);
136   }
137   int64 bytes() const { return scratch_tensor_.TotalBytes(); }
138   int64 size() const { return scratch_tensor_.NumElements(); }
139   const string& debug_info() const { return debug_info_; }
140 
141   Tensor& tensor() { return scratch_tensor_; }
142   const Tensor& tensor() const { return scratch_tensor_; }
143 
144   // Returns true if this ScratchSpace is in host memory.
145   bool on_host() const { return on_host_; }
146 
147  protected:
148   OpKernelContext* context() const { return context_; }
149 
150  private:
151   OpKernelContext* context_;  // not owned
152   const string debug_info_;
153   const bool on_host_;
154   Tensor scratch_tensor_;
155 };
156 
157 }  // namespace tensorflow
158 
159 #endif  // TENSORFLOW_USE_ROCM
160 
161 #endif  // TENSORFLOW_CORE_KERNELS_LINALG_ROCM_SOLVERS_H_
162