/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_

// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

#if GOOGLE_CUDA
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
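// Illustrative sketch (not part of this header's API): how CUDAComplex() and
// CublasAdjointOp() are typically used when calling a cuBLAS routine directly.
// The handle, dimensions, and device pointers below are hypothetical.
//
//   std::complex<float> alpha(1.0f, 0.0f), beta(0.0f, 0.0f);
//   // dev_a, dev_b, dev_c are device pointers to std::complex<float> data.
//   cublasCgemm(cublas_handle, CUBLAS_OP_N,
//               CublasAdjointOp<std::complex<float>>() /* CUBLAS_OP_C */,
//               m, n, k, CUDAComplex(&alpha), CUDAComplex(dev_a), lda,
//               CUDAComplex(dev_b), ldb, CUDAComplex(&beta),
//               CUDAComplex(dev_c), ldc);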
// Container of LAPACK info data (an array of int) generated on-device by
// a CudaSolver call. One or more such objects can be passed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The CudaSolver class provides a simplified templated API for the dense
// linear solvers implemented in cuSolverDN
// (http://docs.nvidia.com/cuda/cusolver) and cuBlas
// (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSolver
// object. To check the final status of the kernels run, call
// CheckLapackInfoAndDeleteSolverAsync() with the CudaSolver object and a
// callback that will be invoked with the status of the kernels launched thus
// far as arguments.
//
// Example of an asynchronous TensorFlow kernel using CudaSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                     dev_info,
//                                                     std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class CudaSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit CudaSolver(OpKernelContext* context);
  virtual ~CudaSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);
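  // Illustrative sketch (hypothetical names, not prescribed by this header):
  // a custom info_checker_callback for the overload above, assuming it runs
  // inside an AsyncOpKernel with `context`, `done`, `solver`, and `dev_info`
  // available as in the class-level example.
  //
  //   auto info_checker = [context, done](
  //       const Status& status,
  //       const std::vector<HostLapackInfo>& host_infos) {
  //     // Reports any non-zero LAPACK info (or failed launch) and calls done.
  //     OP_REQUIRES_OK_ASYNC(context, status, done);
  //     done();
  //   };
  //   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
  //                                                   dev_info, info_checker);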
  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

  // Returns a ScratchSpace. The CudaSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const std::string& debug_info,
                                       bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64 size,
                                       const std::string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // CudaSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64 size,
                                              const std::string& debug_info);

  // Allocates a temporary tensor that will live for the duration of the
  // CudaSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK; see, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  // C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;

  // Computes the Cholesky factorization A = L * L^H for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

#if CUDA_VERSION >= 9020
  // Computes the Cholesky factorization A = L * L^H for a batch of small
  // matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-potrfBatched
  template <typename Scalar>
  Status PotrfBatched(cublasFillMode_t uplo, int n,
                      const Scalar* const host_a_dev_ptrs[], int lda,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
#endif  // CUDA_VERSION >= 9020

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;
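  // Illustrative sketch (hypothetical names, not prescribed by this header):
  // solving A * X = B with Getrf followed by Getrs, assuming `solver`, `n`,
  // `nrhs`, and device pointers `dev_a` / `dev_b` are set up as in the
  // class-level example. In an AsyncOpKernel one would typically use
  // OP_REQUIRES_OK_ASYNC instead of TF_RETURN_IF_ERROR.
  //
  //   ScratchSpace<int> pivots =
  //       solver->GetScratchSpace<int>(n, "pivots", /* on_host */ false);
  //   DeviceLapackInfo getrf_info = solver->GetDeviceLapackInfo(1, "getrf");
  //   TF_RETURN_IF_ERROR(solver->Getrf(n, n, dev_a, n, pivots.mutable_data(),
  //                                    getrf_info.mutable_data()));
  //   DeviceLapackInfo getrs_info = solver->GetDeviceLapackInfo(1, "getrs");
  //   TF_RETURN_IF_ERROR(solver->Getrs(CUBLAS_OP_N, n, nrhs, dev_a, n,
  //                                    pivots.data(), dev_b, n,
  //                                    getrs_info.mutable_data()));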
  // Computes partially pivoted LU factorizations for a batch of small
  // matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // Notice that lapack_info is returned on the host, as opposed to
  // most of the other functions that return it on the device. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, int* host_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[],
                      int ldainv, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrite matrix C by product of C and the unitary Householder matrix Q.
  // The Householder matrix Q is represented by the output from Geqrf in dev_a
  // and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites QR factorization produced by Geqrf by the unitary Householder
  // matrix Q. On input, the Householder matrix Q is represented by the output
  // from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten with the
  // first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;
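  // Illustrative sketch (hypothetical names, not prescribed by this header):
  // computing a reduced QR factorization of a column-major m x n matrix
  // (m >= n) and materializing the first n columns of Q in place. `solver`,
  // `m`, `n`, the device pointer `dev_a`, and the DeviceLapackInfo objects
  // `geqrf_info` / `ungqr_info` (from GetDeviceLapackInfo) are assumptions.
  //
  //   ScratchSpace<Scalar> tau =
  //       solver->GetScratchSpace<Scalar>(n, "tau", /* on_host */ false);
  //   TF_RETURN_IF_ERROR(solver->Geqrf(m, n, dev_a, m, tau.mutable_data(),
  //                                    geqrf_info.mutable_data()));
  //   // dev_a now holds R in its upper triangle and the Householder vectors
  //   // below the diagonal; overwrite it with the first n columns of Q.
  //   TF_RETURN_IF_ERROR(solver->Ungqr(m, n, n, dev_a, m, tau.data(),
  //                                    ungqr_info.mutable_data()));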
  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n,
               Scalar* dev_A, int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
               Scalar* dev_VT, int ldvt,
               int* dev_lapack_info) TF_MUST_USE_RESULT;
  template <typename Scalar>
  Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A,
                       int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                       Scalar* dev_V, int ldv, int* dev_lapack_info,
                       int batch_size);

  // Triangular solve.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
  template <typename Scalar>
  Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
              cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
              const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
              int ldb);

  template <typename Scalar>
  Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
              cublasDiagType_t diag, int n, const Scalar* A, int lda,
              Scalar* x, int incx);

  // See:
  // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
  template <typename Scalar>
  Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
                     cublasOperation_t trans, cublasDiagType_t diag, int m,
                     int n, const Scalar* alpha,
                     const Scalar* const dev_Aarray[], int lda,
                     Scalar* dev_Barray[], int ldb, int batch_size);
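  // Illustrative sketch (hypothetical names, not prescribed by this header):
  // solving L * X = B in place for a lower-triangular matrix L, assuming
  // `solver`, dimensions `n` / `nrhs`, and device pointers `dev_l` / `dev_b`.
  //
  //   const Scalar alpha(1);
  //   TF_RETURN_IF_ERROR(
  //       solver->Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER, CUBLAS_OP_N,
  //                    CUBLAS_DIAG_NON_UNIT, n, nrhs, &alpha, dev_l, n,
  //                    dev_b, n));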
 private:
  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};
#endif  // GOOGLE_CUDA

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64 size,
               const std::string& debug_info, bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const std::string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64 i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64 i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
  int64 size() const { return scratch_tensor_.NumElements(); }
  const std::string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const std::string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};
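// Illustrative sketch (hypothetical names, not prescribed by this header):
// allocating pinned host scratch memory inside a kernel; `context` and
// `num_elems` are assumptions.
//
//   ScratchSpace<int> host_scratch(context, num_elems, "host_scratch",
//                                  /* on_host */ true);
//   int* dst = host_scratch.mutable_data();  // Pinned host pointer.
//
// When a scratch space is handed out by CudaSolver::GetScratchSpace, the
// solver additionally keeps a TensorReference to the underlying tensor so the
// buffer stays alive for the lifetime of the solver object.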
class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64 size,
                 const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64 size,
                   const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    se::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};

#if GOOGLE_CUDA
template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
                                                 const std::string& debug_info,
                                                 bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(int64 size,
                                                 const std::string& debug_info,
                                                 bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
    int64 size, const std::string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}
#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_