/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_

// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.

#ifdef GOOGLE_CUDA

#include <functional>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
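
// For illustration, a hedged sketch of how these helpers are typically used
// when forwarding a TensorFlow buffer to a cuBlas call (the pointer names are
// hypothetical, not part of this header):
//
//   std::complex<float>* tf_buffer = ...;                // from a Tensor
//   cuComplex* cublas_buffer = CUDAComplex(tf_buffer);   // reinterpret, no copy
//   cublasOperation_t adj = CublasAdjointOp<std::complex<float>>();
//   // adj == CUBLAS_OP_C for complex types, CUBLAS_OP_T for real types.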

// Container of LAPACK info data (an array of int) generated on-device by
// a CudaSolver call. One or more such objects can be passed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The CudaSolver class provides a simplified templated API for the dense linear
// solvers implemented in cuSolverDN (http://docs.nvidia.com/cuda/cusolver) and
// cuBlas (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSolver
// object. To check the final status of the kernels run, call
// CheckLapackInfoAndDeleteSolverAsync() with the CudaSolver object to register
// a callback that will be invoked with the status of the kernels launched thus
// far as arguments.
//
// Example of an asynchronous TensorFlow kernel using CudaSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
//         std::move(solver), dev_info, std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class CudaSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit CudaSolver(OpKernelContext* context);
  virtual ~CudaSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

  // Returns a ScratchSpace. The CudaSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const string& debug_info, bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64 size, const string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // CudaSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64 size,
                                              const string& debug_info);
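
  // For illustration, a hedged sketch of requesting solver-owned workspace
  // inside a kernel (the size and debug tag values are hypothetical):
  //
  //   ScratchSpace<Scalar> dev_workspace =
  //       solver->GetScratchSpace<Scalar>(workspace_size, "getrf",
  //                                       /* on_host */ false);
  //   DeviceLapackInfo dev_info =
  //       solver->GetDeviceLapackInfo(batch_size, "getrf");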

  // Allocates a temporary tensor that will live for the duration of the
  // CudaSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK; see, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;
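
  // For illustration, a hedged sketch of an out-of-place transpose of an
  // m x n column-major matrix dev_A into the n x m matrix dev_C (the device
  // pointers are hypothetical; with beta == 0 the B argument is not read):
  //
  //   const Scalar alpha(1), beta(0);
  //   TF_RETURN_IF_ERROR(solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, n, m, &alpha,
  //                                   dev_A, m, &beta, dev_C, n, dev_C, n));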

  // Computes the Cholesky factorization A = L * L^T for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;
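
  // For illustration, a hedged sketch of solving a single system A * X = B in
  // place on the device (dev_A, dev_B, dev_pivots and dev_info are
  // hypothetical buffers obtained elsewhere in the kernel):
  //
  //   TF_RETURN_IF_ERROR(solver->Getrf(n, n, dev_A, n, dev_pivots,
  //                                    dev_info.mutable_data()));
  //   TF_RETURN_IF_ERROR(solver->Getrs(CUBLAS_OP_N, n, nrhs, dev_A, n,
  //                                    dev_pivots, dev_B, n,
  //                                    dev_info.mutable_data()));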

  // Computes partially pivoted LU factorizations for a batch of small matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // Notice that lapack_info is returned on the host, as opposed to
  // most of the other functions that return it on the device. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, int* host_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
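
  // For illustration, a hedged sketch of batched inversion via batched LU
  // (host_a_dev_ptrs, host_a_inv_dev_ptrs, dev_pivots and the two
  // DeviceLapackInfo objects are hypothetical):
  //
  //   TF_RETURN_IF_ERROR(solver->GetrfBatched(n, host_a_dev_ptrs, n,
  //                                           dev_pivots, &dev_info_getrf,
  //                                           batch_size));
  //   TF_RETURN_IF_ERROR(solver->GetriBatched(n, host_a_dev_ptrs, n,
  //                                           dev_pivots, host_a_inv_dev_ptrs,
  //                                           n, &dev_info_getri, batch_size));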

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites matrix C with the product of C and the unitary Householder
  // matrix Q. The Householder matrix Q is represented by the output from Geqrf
  // in dev_a and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites the QR factorization produced by Geqrf with the unitary
  // Householder matrix Q. On input, the Householder matrix Q is represented by
  // the output from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten
  // with the first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;
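
  // For illustration, a hedged sketch of forming the thin QR factors of an
  // m x n matrix (m >= n) in place; dev_A, dev_tau and dev_info are
  // hypothetical device buffers:
  //
  //   TF_RETURN_IF_ERROR(solver->Geqrf(m, n, dev_A, m, dev_tau,
  //                                    dev_info.mutable_data()));
  //   // dev_A now holds R in its upper triangle and the Householder vectors
  //   // below the diagonal; expand the first n columns of Q in place:
  //   TF_RETURN_IF_ERROR(solver->Ungqr(m, n, n, dev_A, m, dev_tau,
  //                                    dev_info.mutable_data()));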

  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;
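
  // For illustration, a hedged sketch of computing eigenvalues and
  // eigenvectors of a Hermitian matrix stored in dev_A (dev_A, dev_eigvals and
  // dev_info are hypothetical buffers; dev_eigvals has the real type matching
  // Scalar):
  //
  //   TF_RETURN_IF_ERROR(solver->Heevd(CUSOLVER_EIG_MODE_VECTOR,
  //                                    CUBLAS_FILL_MODE_LOWER, n, dev_A, n,
  //                                    dev_eigvals, dev_info.mutable_data()));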

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
               int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
               int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT;
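
  // For illustration, a hedged sketch of a full SVD A = U * S * V^T of an
  // m x n real matrix with m >= n ('A' requests all singular vectors; dev_A,
  // dev_S, dev_U, dev_VT and dev_info are hypothetical device buffers):
  //
  //   TF_RETURN_IF_ERROR(solver->Gesvd('A', 'A', m, n, dev_A, m, dev_S,
  //                                    dev_U, m, dev_VT, n,
  //                                    dev_info.mutable_data()));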
  template <typename Scalar>
  Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A,
                       int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                       Scalar* dev_V, int ldv, int* dev_lapack_info,
                       int batch_size);

 private:
  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
               bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64 i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64 i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
  int64 size() const { return scratch_tensor_.NumElements(); }
  const string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};
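
// For illustration, a hedged sketch of allocating pinned host staging memory
// directly (the size and tag are hypothetical; inside a kernel that already
// has a CudaSolver, prefer CudaSolver::GetScratchSpace so the solver keeps the
// tensor alive for its lifetime):
//
//   ScratchSpace<float> host_staging(context, num_elems, "copy_back",
//                                    /* on_host */ true);
//   float* dst = host_staging.mutable_data();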

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64 size, const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64 size,
                   const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    se::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};
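
// For illustration, a hedged sketch of inspecting the info values on the host:
// launch the copy with CopyToHost, synchronize the stream, then read each
// entry (dev_info is a hypothetical DeviceLapackInfo; in asynchronous kernels,
// prefer CudaSolver::CheckLapackInfoAndDeleteSolverAsync instead of blocking):
//
//   bool copied = false;
//   HostLapackInfo host_info = dev_info.CopyToHost(&copied);
//   if (copied) {
//     // ... after the stream has been synchronized ...
//     for (int i = 0; i < host_info.size(); ++i) {
//       if (host_info(i) != 0) {
//         // Factorization i failed; see the LAPACK docs for the meaning.
//       }
//     }
//   }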

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
                                                 const string& debug_info,
                                                 bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(int64 size,
                                                 const string& debug_info,
                                                 bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
    int64 size, const string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_