/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_

// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

#if GOOGLE_CUDA
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
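
// Example (a sketch, not part of the original source): forwarding
// std::complex<float> buffers to a raw cuBLAS call and selecting the adjoint
// op for the scalar type. The handle, dimensions, scalars and device pointers
// (dev_a, dev_b, dev_c, alpha, beta) are placeholders assumed to exist in the
// caller:
//
//   cublasCgemm(cublas_handle,
//               CublasAdjointOp<std::complex<float>>(),  // CUBLAS_OP_C
//               CUBLAS_OP_N, m, n, k,
//               CUDAComplex(&alpha), CUDAComplex(dev_a), lda,
//               CUDAComplex(dev_b), ldb,
//               CUDAComplex(&beta), CUDAComplex(dev_c), ldc);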

// Container of LAPACK info data (an array of int) generated on-device by
// a CudaSolver call. One or more such objects can be passed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The CudaSolver class provides a simplified templated API for the dense linear
// solvers implemented in cuSolverDN (http://docs.nvidia.com/cuda/cusolver) and
// cuBlas (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSolver
// object. To check the final status of the kernels run, pass the CudaSolver
// object to CheckLapackInfoAndDeleteSolverAsync() to set a callback that
// will be invoked with the status of the kernels launched thus far as
// arguments.
//
// Example of an asynchronous TensorFlow kernel using CudaSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
//         std::move(solver), dev_info, std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class CudaSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit CudaSolver(OpKernelContext* context);
  virtual ~CudaSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

  // Returns a ScratchSpace. The CudaSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const std::string& debug_info,
                                       bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64 size,
                                       const std::string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // CudaSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64 size,
                                              const std::string& debug_info);
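
  // Example (a sketch, not part of the original source): requesting scratch
  // memory and a LAPACK info array from a CudaSolver inside a kernel. The
  // names `solver`, `num_elems` and `batch_size` are placeholders assumed to
  // exist in the surrounding kernel:
  //
  //   ScratchSpace<Scalar> workspace = solver->GetScratchSpace<Scalar>(
  //       num_elems, "workspace", /* on_host */ false);
  //   DeviceLapackInfo dev_info =
  //       solver->GetDeviceLapackInfo(batch_size, "getrf");
  //   // workspace.mutable_data() and dev_info.mutable_data() can then be
  //   // passed to the wrappers declared below.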

  // Allocates a temporary tensor that will live for the duration of the
  // CudaSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK; see, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully.  See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;

  // Computes the Cholesky factorization A = L * L^H for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

#if CUDA_VERSION >= 9020
  // Computes the Cholesky factorization A = L * L^H for a batch of small
  // matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-potrfBatched
  template <typename Scalar>
  Status PotrfBatched(cublasFillMode_t uplo, int n,
                      const Scalar* const host_a_dev_ptrs[], int lda,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
#endif  // CUDA_VERSION >= 9020

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;

  // Computes partially pivoted LU factorizations for a batch of small matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // Notice that lapack_info is returned on the host, as opposed to
  // most of the other functions that return it on the device. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, int* host_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
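
  // Example (a sketch, not part of the original source): factoring and then
  // solving a batch of systems with GetrfBatched followed by GetrsBatched.
  // The host-side arrays of device pointers (host_a_dev_ptrs, host_b_dev_ptrs),
  // dev_pivots and batch_size are placeholders assumed to be set up already,
  // inside a helper that returns Status:
  //
  //   DeviceLapackInfo dev_info =
  //       solver->GetDeviceLapackInfo(batch_size, "getrf");
  //   TF_RETURN_IF_ERROR(solver->GetrfBatched(n, host_a_dev_ptrs, n,
  //                                           dev_pivots, &dev_info,
  //                                           batch_size));
  //   int host_info = 0;
  //   TF_RETURN_IF_ERROR(solver->GetrsBatched(CUBLAS_OP_N, n, nrhs,
  //                                           host_a_dev_ptrs, n, dev_pivots,
  //                                           host_b_dev_ptrs, n, &host_info,
  //                                           batch_size));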

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrite matrix C by product of C and the unitary Householder matrix Q.
  // The Householder matrix Q is represented by the output from Geqrf in dev_a
  // and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites QR factorization produced by Geqrf by the unitary Householder
  // matrix Q. On input, the Householder matrix Q is represented by the output
  // from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten with the
  // first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
               int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
               int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT;
  template <typename Scalar>
  Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A,
                       int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                       Scalar* dev_V, int ldv, int* dev_lapack_info,
                       int batch_size);

  // Triangular solve.
  // Returns Status::OK() if the kernel was launched successfully.
  // See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
  template <typename Scalar>
  Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
              cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
              const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
              int ldb);
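
  // Example (a sketch, not part of the original source): using Trsm to solve
  // the lower triangular system L * X = B in place in dev_B, where L is stored
  // in the lower triangle of dev_L. The pointers, dimensions, `solver`,
  // `context` and `done` are placeholders assumed to exist in the calling
  // kernel:
  //
  //   const Scalar alpha = Scalar(1);
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       solver->Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER, CUBLAS_OP_N,
  //                    CUBLAS_DIAG_NON_UNIT, m, n, &alpha, dev_L, m, dev_B, m),
  //       done);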

  template <typename Scalar>
  Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
              cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x,
              int incx);

  // See
  // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
  template <typename Scalar>
  Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
                     cublasOperation_t trans, cublasDiagType_t diag, int m,
                     int n, const Scalar* alpha,
                     const Scalar* const dev_Aarray[], int lda,
                     Scalar* dev_Barray[], int ldb, int batch_size);

 private:
  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};
#endif  // GOOGLE_CUDA

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64 size,
               const std::string& debug_info, bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const std::string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64 i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64 i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
  int64 size() const { return scratch_tensor_.NumElements(); }
  const std::string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const std::string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};
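
// Example (a sketch, not part of the original source): allocating a pinned
// host staging buffer directly with ScratchSpace, outside of a CudaSolver.
// `context` is the OpKernelContext of the calling kernel and `num_elems` is a
// placeholder:
//
//   ScratchSpace<float> host_staging(context, num_elems, "staging",
//                                    /* on_host */ true);
//   // host_staging.mutable_data() points to num_elems floats in pinned host
//   // memory; host_staging.tensor() keeps the backing Tensor alive.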

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64 size,
                 const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64 size,
                   const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    se::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};

#if GOOGLE_CUDA
template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
                                                 const std::string& debug_info,
                                                 bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(int64 size,
                                                 const std::string& debug_info,
                                                 bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
    int64 size, const std::string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}
#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_