/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_

// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.

#ifdef GOOGLE_CUDA

#include <functional>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Returns the cuBlas adjoint operation for real and complex types:
// conjugate transpose for complex types, plain transpose otherwise.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}

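// A minimal usage sketch of the helpers above (illustrative only; the device
// pointer `dev_a` is a hypothetical name assumed to be provided by the
// caller): reinterpreting a std::complex<float> buffer for cuBlas and picking
// the adjoint operation for the scalar type.
//
//   const std::complex<float>* dev_a = ...;      // device buffer
//   const cuComplex* cu_a = CUDAComplex(dev_a);  // same address, cuBlas type
//   cublasOperation_t adj =
//       CublasAdjointOp<std::complex<float>>();  // CUBLAS_OP_C for complex
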
// Container of LAPACK info data (an array of int) generated on-device by
// a CudaSolver call. One or more such objects can be passed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The CudaSolver class provides a simplified templated API for the dense
// linear solvers implemented in cuSolverDN
// (http://docs.nvidia.com/cuda/cusolver) and cuBlas
// (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSolver
// object. To check the final status of the kernels run, call
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() with the CudaSolver object
// to register a callback that will be invoked with the status of the kernels
// launched thus far as arguments.
//
// Example of an asynchronous TensorFlow kernel using CudaSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                     dev_info,
//                                                     std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class CudaSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit CudaSolver(OpKernelContext* context);
  virtual ~CudaSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

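  // A minimal sketch of the first overload with a hand-written checker
  // (illustrative; `solver`, `dev_info` and `done` are assumed to be set up as
  // in the class-level example above):
  //
  //   auto checker = [context, done](
  //                      const Status& status,
  //                      const std::vector<HostLapackInfo>& host_infos) {
  //     OP_REQUIRES_OK_ASYNC(context, status, done);
  //     done();
  //   };
  //   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
  //                                                   dev_info,
  //                                                   std::move(checker));
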
  // Returns a ScratchSpace. The CudaSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const string& debug_info, bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64 size, const string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // CudaSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64 size,
                                              const string& debug_info);

  // Allocates a temporary tensor that will live for the duration of the
  // CudaSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK; see, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;

  // Computes the Cholesky factorization A = L * L^T for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // LU factorization.
  // Computes the LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;

  // Computes partially pivoted LU factorizations for a batch of small
  // matrices. Returns Status::OK() if the kernel was launched successfully.
  // See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // Notice that lapack_info is returned on the host, as opposed to
  // most of the other functions that return it on the device. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, int* host_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

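  // A minimal sketch of a batched LU solve combining the two calls above
  // (illustrative; n, nrhs, batch_size, the per-matrix pointer arrays
  // a_dev_ptrs / b_dev_ptrs, dev_pivots and done are hypothetical names
  // assumed to be set up by the caller as expected by these wrappers):
  //
  //   auto dev_info = solver->GetDeviceLapackInfo(batch_size, "getrf");
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       solver->GetrfBatched(n, a_dev_ptrs, n, dev_pivots, &dev_info,
  //                            batch_size),
  //       done);
  //   ScratchSpace<int> host_info =
  //       solver->GetScratchSpace<int>(batch_size, "getrs",
  //                                    /* on_host */ true);
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       solver->GetrsBatched(CUBLAS_OP_N, n, nrhs, a_dev_ptrs, n,
  //                            dev_pivots, b_dev_ptrs, n,
  //                            host_info.mutable_data(), batch_size),
  //       done);
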
  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[],
                      int ldainv, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites matrix C with the product of C and the unitary Householder
  // matrix Q. The Householder matrix Q is represented by the output from
  // Geqrf in dev_a and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites the QR factorization produced by Geqrf with the unitary
  // Householder matrix Q. On input, the Householder matrix Q is represented by
  // the output from Geqrf in dev_a and dev_tau. On output, dev_a is
  // overwritten with the first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;

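  // A minimal sketch of computing a QR factorization and forming the explicit
  // thin Q factor (illustrative; m >= n, and the m x n column-major device
  // matrix dev_a and the length-n device buffer dev_tau are hypothetical names
  // assumed to be allocated by the caller):
  //
  //   auto geqrf_info = solver->GetDeviceLapackInfo(1, "geqrf");
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       solver->Geqrf(m, n, dev_a, m, dev_tau, geqrf_info.mutable_data()),
  //       done);
  //   auto ungqr_info = solver->GetDeviceLapackInfo(1, "ungqr");
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       solver->Ungqr(m, n, /* k = */ n, dev_a, m, dev_tau,
  //                     ungqr_info.mutable_data()),
  //       done);
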
  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
               int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
               int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT;
  template <typename Scalar>
  Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A,
                       int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                       Scalar* dev_V, int ldv, int* dev_lapack_info,
                       int batch_size);

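  // A minimal sketch of the Heevd wrapper declared above (illustrative; the
  // n x n column-major device matrix dev_matrix and the length-n device buffer
  // dev_eigenvalues of real scalars are hypothetical names assumed to be
  // allocated by the caller):
  //
  //   auto heevd_info = solver->GetDeviceLapackInfo(1, "heevd");
  //   OP_REQUIRES_OK_ASYNC(
  //       context,
  //       solver->Heevd(CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, n,
  //                     dev_matrix, n, dev_eigenvalues,
  //                     heevd_info.mutable_data()),
  //       done);
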
 private:
  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
               bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64 i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64 i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
  int64 size() const { return scratch_tensor_.NumElements(); }
  const string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64 size,
                 const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64 size,
                   const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    se::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
                                                 const string& debug_info,
                                                 bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(int64 size,
                                                 const string& debug_info,
                                                 bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
    int64 size, const string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_