/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
   ==============================================================================
*/
16 #ifdef GOOGLE_CUDA
17 #include "tensorflow/core/kernels/cuda_solvers.h"
19 #include <chrono>
20 #include <complex>
21 #include <unordered_map>
22 #include <vector>
24 #include "third_party/gpus/cuda/include/cublas_v2.h"
25 #include "third_party/gpus/cuda/include/cusolverDn.h"
26 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/stream_executor.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
38 // The CUDA cublas_api.h API contains const-correctness errors. Instead of
39 // casting away constness on our data, we instead reinterpret the CuBLAS
40 // functions as what they were clearly meant to be, and thus we can call
41 // the functions naturally.
42 //
43 // (The error is that input-only arrays are bound to parameter types
44 // "const T**" instead of the correct "const T* const*".)
45 extern "C" {
46 using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
47                                const float* const*, int, const int*, float**,
48                                int, int*, int);
49 using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
50                                const double* const*, int, const int*, double**,
51                                int, int*, int);
52 using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
53                                const float2* const*, int, const int*, float2**,
54                                int, int*, int);
55 using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
56                                const double2* const*, int, const int*,
57                                double2**, int, int*, int);
59 using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
60                                const int*, float**, int, int*, int);
61 using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
62                                const int*, double**, int, int*, int);
63 using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
64                                const int*, float2**, int, int*, int);
65 using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
66                                const int*, double2**, int, int*, int);
68 using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
69                                 float**, int, int*, int);
70 using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
71                                 double**, int, int*, int);
72 using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
73                                 float2**, int, int*, int);
74 using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
75                                 double2**, int, int*, int);
76 }
78 namespace tensorflow {
79 namespace {
81 using se::cuda::ScopedActivateExecutorContext;
CopyHostToDevice(OpKernelContext * context,void * dst,const void * src,uint64 bytes)83 inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
84                              const void* src, uint64 bytes) {
85   auto stream = context->op_device_context()->stream();
86   se::DeviceMemoryBase wrapped_dst(dst);
87   return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
88 }
90 // A set of initialized handles to the underlying Cuda libraries used by
91 // CudaSolver. We maintain one such set of handles per unique stream.
92 struct CudaSolverHandles {
CudaSolverHandlestensorflow::__anona8fbe2f90111::CudaSolverHandles93   explicit CudaSolverHandles(cudaStream_t stream) {
94     CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
95         << "Failed to create cuSolverDN instance.";
96     CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
98         << "Failed to set cuSolverDN stream.";
99     CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
100         << "Failed to create cuBlas instance.";
101     CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
102         << "Failed to set cuBlas stream.";
103   }
~CudaSolverHandlestensorflow::__anona8fbe2f90111::CudaSolverHandles105   ~CudaSolverHandles() {
106     CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
107         << "Failed to destroy cuBlas instance.";
108     CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
109         << "Failed to destroy cuSolverDN instance.";
110   }
111   cublasHandle_t cublas_handle;
112   cusolverDnHandle_t cusolver_dn_handle;
113 };
115 static mutex handle_map_mutex(LINKER_INITIALIZED);
117 using HandleMap =
118     std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;
120 // Returns a singleton map used for storing initialized handles for each unique
121 // cuda stream.
GetHandleMapSingleton()122 HandleMap* GetHandleMapSingleton() {
123   static HandleMap* cm = new HandleMap;
124   return cm;
125 }
127 }  // namespace
// Evaluates `expr` once and, if the result is not CUSOLVER_STATUS_SUCCESS,
// returns an errors::Internal status carrying the file, line and status value
// from the enclosing function.
#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)
// Evaluates `expr` once and, if the result is not CUBLAS_STATUS_SUCCESS,
// returns an errors::Internal status carrying the file, line and status value
// from the enclosing function.
#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)
CudaSolver(OpKernelContext * context)148 CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
149   mutex_lock lock(handle_map_mutex);
150   const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
151       reinterpret_cast<const cudaStream_t*>(context->op_device_context()
152                                                 ->stream()
153                                                 ->implementation()
154                                                 ->GpuStreamMemberHack()));
155   cuda_stream_ = *cu_stream_ptr;
156   HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
157   auto it = handle_map->find(cuda_stream_);
158   if (it == handle_map->end()) {
159     LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
160     // Previously unseen Cuda stream. Initialize a set of Cuda solver library
161     // handles for it.
162     std::unique_ptr<CudaSolverHandles> new_handles(
163         new CudaSolverHandles(cuda_stream_));
164     it =
165         handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
166             .first;
167   }
168   cusolver_dn_handle_ = it->second->cusolver_dn_handle;
169   cublas_handle_ = it->second->cublas_handle;
170 }
~CudaSolver()172 CudaSolver::~CudaSolver() {
173   for (auto tensor_ref : scratch_tensor_refs_) {
174     tensor_ref.Unref();
175   }
176 }
178 // static
CheckLapackInfoAndDeleteSolverAsync(std::unique_ptr<CudaSolver> solver,const std::vector<DeviceLapackInfo> & dev_lapack_infos,std::function<void (const Status &,const std::vector<HostLapackInfo> &)> info_checker_callback)179 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
180     std::unique_ptr<CudaSolver> solver,
181     const std::vector<DeviceLapackInfo>& dev_lapack_infos,
182     std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
183         info_checker_callback) {
184   CHECK(info_checker_callback != nullptr);
185   std::vector<HostLapackInfo> host_lapack_infos;
186   if (dev_lapack_infos.empty()) {
187     info_checker_callback(Status::OK(), host_lapack_infos);
188     return;
189   }
191   // Launch memcpys to copy info back from the device to the host.
192   for (const auto& dev_lapack_info : dev_lapack_infos) {
193     bool success = true;
194     auto host_copy = dev_lapack_info.CopyToHost(&success);
195     OP_REQUIRES(
196         solver->context(), success,
197         errors::Internal(
198             "Failed to launch copy of dev_lapack_info to host, debug_info = ",
199             dev_lapack_info.debug_info()));
200     host_lapack_infos.push_back(std::move(host_copy));
201   }
203   // This callback checks that all batch items in all calls were processed
204   // successfully and passes status to the info_checker_callback accordingly.
205   auto* stream = solver->context()->op_device_context()->stream();
206   auto wrapped_info_checker_callback =
207       [stream](
208           CudaSolver* solver,
209           std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
210               info_checker_callback,
211           std::vector<HostLapackInfo> host_lapack_infos) {
212         ScopedActivateExecutorContext scoped_activation{stream->parent()};
213         Status status;
214         for (const auto& host_lapack_info : host_lapack_infos) {
215           for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
216             const int info_value = host_lapack_info(i);
217             if (info_value != 0) {
218               status = errors::InvalidArgument(
219                   "Got info = ", info_value, " for batch index ", i,
220                   ", expected info = 0. Debug_info = ",
221                   host_lapack_info.debug_info());
222             }
223           }
224           if (!status.ok()) {
225             break;
226           }
227         }
228         // Delete solver to release temp tensor refs.
229         delete solver;
231         // Delegate further error checking to provided functor.
232         info_checker_callback(status, host_lapack_infos);
233       };
234   // Note: An std::function cannot have unique_ptr arguments (it must be copy
235   // constructible and therefore so must its arguments). Therefore, we release
236   // solver into a raw pointer to be deleted at the end of
237   // wrapped_info_checker_callback.
238   // Release ownership of solver. It will be deleted in the cb callback.
239   auto solver_raw_ptr = solver.release();
240   auto cb =
241       std::bind(wrapped_info_checker_callback, solver_raw_ptr,
242                 std::move(info_checker_callback), std::move(host_lapack_infos));
244   solver_raw_ptr->context()
245       ->device()
246       ->tensorflow_gpu_device_info()
247       ->event_mgr->ThenExecute(stream, std::move(cb));
248 }
250 // static
CheckLapackInfoAndDeleteSolverAsync(std::unique_ptr<CudaSolver> solver,const std::vector<DeviceLapackInfo> & dev_lapack_info,AsyncOpKernel::DoneCallback done)251 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
252     std::unique_ptr<CudaSolver> solver,
253     const std::vector<DeviceLapackInfo>& dev_lapack_info,
254     AsyncOpKernel::DoneCallback done) {
255   OpKernelContext* context = solver->context();
256   auto wrapped_done = [context, done](
257                           const Status& status,
258                           const std::vector<HostLapackInfo>& /* unused */) {
259     if (done != nullptr) {
260       OP_REQUIRES_OK_ASYNC(context, status, done);
261       done();
262     } else {
263       OP_REQUIRES_OK(context, status);
264     }
265   };
266   CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
267                                       wrapped_done);
268 }
270 // Allocates a temporary tensor. The CudaSolver object maintains a
271 // TensorReference to the underlying Tensor to prevent it from being deallocated
272 // prematurely.
allocate_scoped_tensor(DataType type,const TensorShape & shape,Tensor * out_temp)273 Status CudaSolver::allocate_scoped_tensor(DataType type,
274                                           const TensorShape& shape,
275                                           Tensor* out_temp) {
276   const Status status = context_->allocate_temp(type, shape, out_temp);
277   if (status.ok()) {
278     scratch_tensor_refs_.emplace_back(*out_temp);
279   }
280   return status;
281 }
forward_input_or_allocate_scoped_tensor(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)283 Status CudaSolver::forward_input_or_allocate_scoped_tensor(
284     gtl::ArraySlice<int> candidate_input_indices, DataType type,
285     const TensorShape& shape, Tensor* out_temp) {
286   const Status status = context_->forward_input_or_allocate_temp(
287       candidate_input_indices, type, shape, out_temp);
288   if (status.ok()) {
289     scratch_tensor_refs_.emplace_back(*out_temp);
290   }
291   return status;
292 }
// Macro that specializes a solver method for all 4 standard
// numeric types. `m` is invoked as m(<C++ scalar type>, <type prefix>) with
// the prefixes S, D, C and Z.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
// Same as above, restricted to the two real types.
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)

// Macros to construct cusolverDn method names, e.g.
// DN_SOLVER_FN(potrf, S) -> cusolverDnSpotrf,
// DN_SOLVER_NAME(potrf, S) -> "cusolverDnSpotrf",
// DN_BUFSIZE_FN(potrf, S) -> cusolverDnSpotrf_bufferSize.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize

// Macros to construct cublas method names, e.g.
// BLAS_SOLVER_FN(geam, S) -> cublasSgeam,
// BLAS_SOLVER_NAME(geam, S) -> "cublasSgeam".
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be threadsafe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================
324 template <typename Scalar, typename SolverFnT>
GeamImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasOperation_t transa,cublasOperation_t transb,int m,int n,const Scalar * alpha,const Scalar * A,int lda,const Scalar * beta,const Scalar * B,int ldb,Scalar * C,int ldc)325 static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
326                               cublasOperation_t transa,
327                               cublasOperation_t transb, int m, int n,
328                               const Scalar* alpha, /* host or device pointer */
329                               const Scalar* A, int lda,
330                               const Scalar* beta, /* host or device pointer */
331                               const Scalar* B, int ldb, Scalar* C, int ldc) {
332   mutex_lock lock(handle_map_mutex);
333   using CudaScalar = typename CUDAComplexT<Scalar>::type;
334   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
335                                    reinterpret_cast<const CudaScalar*>(alpha),
336                                    reinterpret_cast<const CudaScalar*>(A), lda,
337                                    reinterpret_cast<const CudaScalar*>(beta),
338                                    reinterpret_cast<const CudaScalar*>(B), ldb,
339                                    reinterpret_cast<CudaScalar*>(C), ldc));
340   return Status::OK();
341 }
343 #define GEAM_INSTANCE(Scalar, type_prefix)                                     \
344   template <>                                                                  \
345   Status CudaSolver::Geam<Scalar>(                                             \
346       cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
347       const Scalar* alpha, /* host or device pointer */                        \
348       const Scalar* A, int lda,                                                \
349       const Scalar* beta, /* host or device pointer */                         \
350       const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
351     return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
352                     transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
353   }
357 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
PotrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasFillMode_t uplo,int n,Scalar * A,int lda,int * dev_lapack_info)358 static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
359                                CudaSolver* cuda_solver,
360                                OpKernelContext* context,
361                                cusolverDnHandle_t cusolver_dn_handle,
362                                cublasFillMode_t uplo, int n, Scalar* A, int lda,
363                                int* dev_lapack_info) {
364   mutex_lock lock(handle_map_mutex);
365   /* Get amount of workspace memory required. */
366   int lwork;
368       bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
369   /* Allocate device memory for workspace. */
370   auto dev_workspace =
371       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
372   /* Launch the solver kernel. */
374       cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
375       CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
376   return Status::OK();
377 }
379 #define POTRF_INSTANCE(Scalar, type_prefix)                                  \
380   template <>                                                                \
381   Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
382                                    int lda, int* dev_lapack_info) {          \
383     return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
384                      DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
385                      cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
386   }
390 #if CUDA_VERSION >= 9020
391 template <typename Scalar, typename SolverFnT>
PotrfBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasFillMode_t uplo,int n,const Scalar * const host_a_dev_ptrs[],int lda,DeviceLapackInfo * dev_lapack_info,int batch_size)392 static inline Status PotrfBatchedImpl(
393     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
394     cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
395     const Scalar* const host_a_dev_ptrs[], int lda,
396     DeviceLapackInfo* dev_lapack_info, int batch_size) {
397   mutex_lock lock(handle_map_mutex);
398   using CudaScalar = typename CUDAComplexT<Scalar>::type;
399   ScratchSpace<uint8> dev_a_dev_ptrs =
400       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
401                                           /* on_host */ false);
402   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
403                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
404     return errors::Internal("PotrfBatched: failed to copy pointers to device");
405   }
407       solver(cusolver_dn_handle, uplo, n,
408              reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
409              dev_lapack_info->mutable_data(), batch_size));
410   return Status::OK();
411 }
413 #define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                        \
414   template <>                                                              \
415   Status CudaSolver::PotrfBatched(                                         \
416       cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
417       int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
418     return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \
419                             context_, cusolver_dn_handle_, uplo, n,        \
420                             host_a_dev_ptrs, lda, dev_lapack_info,         \
421                             batch_size);                                   \
422   }
425 #endif  // CUDA_VERSION >= 9020
427 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GetrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,Scalar * A,int lda,int * dev_pivots,int * dev_lapack_info)428 static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
429                                CudaSolver* cuda_solver,
430                                OpKernelContext* context,
431                                cusolverDnHandle_t cusolver_dn_handle, int m,
432                                int n, Scalar* A, int lda, int* dev_pivots,
433                                int* dev_lapack_info) {
434   mutex_lock lock(handle_map_mutex);
435   /* Get amount of workspace memory required. */
436   int lwork;
438       bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
439   /* Allocate device memory for workspace. */
440   auto dev_workspace =
441       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
442   /* Launch the solver kernel. */
444       cusolver_dn_handle, m, n, CUDAComplex(A), lda,
445       CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
446   return Status::OK();
447 }
449 #define GETRF_INSTANCE(Scalar, type_prefix)                                 \
450   template <>                                                               \
451   Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
452                                    int* dev_pivots, int* dev_lapack_info) { \
453     return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
454                      DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
455                      cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
456                      dev_lapack_info);                                      \
457   }
461 template <typename Scalar, typename SolverFnT>
GetrsImpl(SolverFnT solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasOperation_t trans,int n,int nrhs,const Scalar * A,int lda,const int * pivots,Scalar * B,int ldb,int * dev_lapack_info)462 static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
463                                cusolverDnHandle_t cusolver_dn_handle,
464                                cublasOperation_t trans, int n, int nrhs,
465                                const Scalar* A, int lda, const int* pivots,
466                                Scalar* B, int ldb, int* dev_lapack_info) {
467   mutex_lock lock(handle_map_mutex);
468   /* Launch the solver kernel. */
469   TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
470                                      CUDAComplex(A), lda, pivots,
471                                      CUDAComplex(B), ldb, dev_lapack_info));
472   return Status::OK();
473 }
475 #define GETRS_INSTANCE(Scalar, type_prefix)                                  \
476   template <>                                                                \
477   Status CudaSolver::Getrs<Scalar>(                                          \
478       cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
479       const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
480     return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
481                      cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
482                      ldb, dev_lapack_info);                                  \
483   }
487 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GeqrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,Scalar * A,int lda,Scalar * tau,int * dev_lapack_info)488 static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
489                                CudaSolver* cuda_solver,
490                                OpKernelContext* context,
491                                cusolverDnHandle_t cusolver_dn_handle, int m,
492                                int n, Scalar* A, int lda, Scalar* tau,
493                                int* dev_lapack_info) {
494   mutex_lock lock(handle_map_mutex);
495   /* Get amount of workspace memory required. */
496   int lwork;
498       bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
499   /* Allocate device memory for workspace. */
500   auto dev_workspace =
501       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
502   /* Launch the solver kernel. */
504       cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
505       CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
506   return Status::OK();
507 }
509 #define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
510   template <>                                                                  \
511   Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
512                                    Scalar* tau, int* dev_lapack_info) {        \
513     return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
514                      DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
515                      cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
516   }
520 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
UnmqrImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasSideMode_t side,cublasOperation_t trans,int m,int n,int k,const Scalar * dev_a,int lda,const Scalar * dev_tau,Scalar * dev_c,int ldc,int * dev_lapack_info)521 static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
522                                CudaSolver* cuda_solver,
523                                OpKernelContext* context,
524                                cusolverDnHandle_t cusolver_dn_handle,
525                                cublasSideMode_t side, cublasOperation_t trans,
526                                int m, int n, int k, const Scalar* dev_a,
527                                int lda, const Scalar* dev_tau, Scalar* dev_c,
528                                int ldc, int* dev_lapack_info) {
529   mutex_lock lock(handle_map_mutex);
530   /* Get amount of workspace memory required. */
531   int lwork;
533       bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
534               CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
535   /* Allocate device memory for workspace. */
536   auto dev_workspace =
537       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
538   /* Launch the solver kernel. */
540       cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
541       CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
542       CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
543   return Status::OK();
544 }
546 // Unfortunately the LAPACK function name differs for the real and complex case
547 // (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
548 // one separately.
549 #define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
550   template <>                                                                 \
551   Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
552                            int m, int n, int k, const Scalar* dev_a, int lda, \
553                            const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
554                            int* dev_lapack_info) {                            \
555     return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
556                      DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
557                      context_, cusolver_dn_handle_, side, trans, m, n, k,     \
558                      dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
559   }
561 UNMQR_INSTANCE(float, or, S);
562 UNMQR_INSTANCE(double, or, D);
563 UNMQR_INSTANCE(complex64, un, C);
564 UNMQR_INSTANCE(complex128, un, Z);
566 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
UngqrImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,int k,Scalar * dev_a,int lda,const Scalar * dev_tau,int * dev_lapack_info)567 static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
568                                CudaSolver* cuda_solver,
569                                OpKernelContext* context,
570                                cusolverDnHandle_t cusolver_dn_handle, int m,
571                                int n, int k, Scalar* dev_a, int lda,
572                                const Scalar* dev_tau, int* dev_lapack_info) {
573   mutex_lock lock(handle_map_mutex);
574   /* Get amount of workspace memory required. */
575   int lwork;
576   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
577                                       CUDAComplex(dev_a), lda,
578                                       CUDAComplex(dev_tau), &lwork));
579   /* Allocate device memory for workspace. */
580   auto dev_workspace =
581       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
582   /* Launch the solver kernel. */
584       solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
585              CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
586              lwork, dev_lapack_info));
587   return Status::OK();
588 }
590 #define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
591   template <>                                                               \
592   Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
593                            const Scalar* dev_tau, int* dev_lapack_info) {   \
594     return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
595                      DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
596                      context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
597                      dev_tau, dev_lapack_info);                             \
598   }
600 UNGQR_INSTANCE(float, or, S);
601 UNGQR_INSTANCE(double, or, D);
602 UNGQR_INSTANCE(complex64, un, C);
603 UNGQR_INSTANCE(complex128, un, Z);
605 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
HeevdImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,Scalar * dev_A,int lda,typename Eigen::NumTraits<Scalar>::Real * dev_W,int * dev_lapack_info)606 static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
607                                CudaSolver* cuda_solver,
608                                OpKernelContext* context,
609                                cusolverDnHandle_t cusolver_dn_handle,
610                                cusolverEigMode_t jobz, cublasFillMode_t uplo,
611                                int n, Scalar* dev_A, int lda,
612                                typename Eigen::NumTraits<Scalar>::Real* dev_W,
613                                int* dev_lapack_info) {
614   mutex_lock lock(handle_map_mutex);
615   /* Get amount of workspace memory required. */
616   int lwork;
617   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
618                                       CUDAComplex(dev_A), lda,
619                                       CUDAComplex(dev_W), &lwork));
620   /* Allocate device memory for workspace. */
621   auto dev_workspace =
622       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
623   /* Launch the solver kernel. */
625       solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
626              CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
627              lwork, dev_lapack_info));
628   return Status::OK();
629 }
631 #define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
632   template <>                                                                  \
633   Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
634                            int n, Scalar* dev_A, int lda,                      \
635                            typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
636                            int* dev_lapack_info) {                             \
637     return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
638                      DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
639                      context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
640                      dev_W, dev_lapack_info);                                  \
641   }
643 HEEVD_INSTANCE(float, sy, S);
644 HEEVD_INSTANCE(double, sy, D);
645 HEEVD_INSTANCE(complex64, he, C);
646 HEEVD_INSTANCE(complex128, he, Z);
648 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GesvdImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,signed char jobu,signed char jobvt,int m,int n,Scalar * A,int lda,Scalar * S,Scalar * U,int ldu,Scalar * VT,int ldvt,int * dev_lapack_info)649 static inline Status GesvdImpl(
650     BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
651     OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
652     signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
653     Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
654   mutex_lock lock(handle_map_mutex);
655   /* Get amount of workspace memory required. */
656   int lwork;
657   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
658   /* Allocate device memory for workspace. */
659   auto dev_workspace =
660       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
661   TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
662                                      CUDAComplex(A), lda, S, CUDAComplex(U),
663                                      ldu, CUDAComplex(VT), ldvt,
664                                      CUDAComplex(dev_workspace.mutable_data()),
665                                      lwork, nullptr, dev_lapack_info));
666   return Status::OK();
667 }
// Defines the CudaSolver::Gesvd<ScalarT> specialization. The body forwards
// to GesvdImpl with the gesvd buffer-size and solver routines selected via
// DN_BUFSIZE_FN / DN_SOLVER_FN for `type_letter`.
#define GESVD_INSTANCE(ScalarT, type_letter)                                 \
  template <>                                                                \
  Status CudaSolver::Gesvd<ScalarT>(                                         \
      signed char jobu, signed char jobvt, int m, int n, ScalarT* dev_A,     \
      int lda, ScalarT* dev_S, ScalarT* dev_U, int ldu, ScalarT* dev_VT,     \
      int ldvt, int* dev_lapack_info) {                                      \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_letter),                      \
                     DN_SOLVER_FN(gesvd, type_letter), this, context_,       \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda,     \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);      \
  }
683 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GesvdjBatchedImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cusolverEigMode_t jobz,int m,int n,Scalar * A,int lda,Scalar * S,Scalar * U,int ldu,Scalar * V,int ldv,int * dev_lapack_info,int batch_size)684 static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
685                                        CudaSolver* cuda_solver,
686                                        OpKernelContext* context,
687                                        cusolverDnHandle_t cusolver_dn_handle,
688                                        cusolverEigMode_t jobz, int m, int n,
689                                        Scalar* A, int lda, Scalar* S, Scalar* U,
690                                        int ldu, Scalar* V, int ldv,
691                                        int* dev_lapack_info, int batch_size) {
692   mutex_lock lock(handle_map_mutex);
693   /* Get amount of workspace memory required. */
694   int lwork;
695   /* Default parameters for gesvdj and gesvdjBatched. */
696   gesvdjInfo_t svdj_info;
697   TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
699       cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
700       ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
701   /* Allocate device memory for workspace. */
702   auto dev_workspace =
703       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
705       cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
706       ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
707       lwork, dev_lapack_info, svdj_info, batch_size));
708   TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
709   return Status::OK();
710 }
// Defines the CudaSolver::GesvdjBatched<ScalarT> specialization. The body
// forwards to GesvdjBatchedImpl with the gesvdjBatched buffer-size and solver
// routines selected via DN_BUFSIZE_FN / DN_SOLVER_FN for `type_letter`.
#define GESVDJBATCHED_INSTANCE(ScalarT, type_letter)                          \
  template <>                                                                 \
  Status CudaSolver::GesvdjBatched<ScalarT>(                                  \
      cusolverEigMode_t jobz, int m, int n, ScalarT* dev_A, int lda,          \
      ScalarT* dev_S, ScalarT* dev_U, int ldu, ScalarT* dev_V, int ldv,       \
      int* dev_lapack_info, int batch_size) {                                 \
    return GesvdjBatchedImpl(                                                 \
        DN_BUFSIZE_FN(gesvdjBatched, type_letter),                            \
        DN_SOLVER_FN(gesvdjBatched, type_letter), this, context_,             \
        cusolver_dn_handle_, jobz, m, n, dev_A, lda, dev_S, dev_U, ldu,       \
        dev_V, ldv, dev_lapack_info, batch_size);                             \
  }
727 //=============================================================================
728 // Wrappers of cuBlas computational methods begin here.
729 //
730 // WARNING to implementers: The function signatures listed in the online docs
731 // are sometimes inaccurate, e.g., are missing 'const' on pointers
732 // to immutable arguments, while the actual headers have them as expected.
733 // Check the actual declarations in the cublas_api.h header file.
734 //=============================================================================
735 template <typename Scalar, typename SolverFnT>
GetrfBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,int * dev_pivots,DeviceLapackInfo * dev_lapack_info,int batch_size)736 static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
737                                       OpKernelContext* context,
738                                       cublasHandle_t cublas_handle, int n,
739                                       const Scalar* const host_a_dev_ptrs[],
740                                       int lda, int* dev_pivots,
741                                       DeviceLapackInfo* dev_lapack_info,
742                                       int batch_size) {
743   mutex_lock lock(handle_map_mutex);
744   using CudaScalar = typename CUDAComplexT<Scalar>::type;
745   ScratchSpace<uint8> dev_a_dev_ptrs =
746       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
747                                           /* on_host */ false);
748   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
749                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
750     return errors::Internal("GetrfBatched: failed to copy pointers to device");
751   }
753       solver(cublas_handle, n,
754              reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
755              dev_pivots, dev_lapack_info->mutable_data(), batch_size));
756   return Status::OK();
757 }
// Defines the CudaSolver::GetrfBatched specialization for one scalar type by
// forwarding to GetrfBatchedImpl with the getrfBatched routine selected via
// BLAS_SOLVER_FN for `type_letter`.
#define GETRF_BATCHED_INSTANCE(ScalarT, type_letter)                          \
  template <>                                                                 \
  Status CudaSolver::GetrfBatched(                                            \
      int n, const ScalarT* const host_a_dev_ptrs[], int lda,                 \
      int* dev_pivots, DeviceLapackInfo* dev_lapack_info, int batch_size) {   \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_letter), this,  \
                            context_, cublas_handle_, n, host_a_dev_ptrs,     \
                            lda, dev_pivots, dev_lapack_info, batch_size);    \
  }
771 template <typename Scalar, typename SolverFnT>
GetrsBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,cublasOperation_t trans,int n,int nrhs,const Scalar * const host_a_dev_ptrs[],int lda,const int * dev_pivots,const Scalar * const host_b_dev_ptrs[],int ldb,int * host_lapack_info,int batch_size)772 static inline Status GetrsBatchedImpl(
773     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
774     cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
775     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
776     const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
777     int batch_size) {
778   mutex_lock lock(handle_map_mutex);
779   using CudaScalar = typename CUDAComplexT<Scalar>::type;
780   ScratchSpace<uint8> dev_a_dev_ptrs =
781       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
782                                           /* on_host */ false);
783   ScratchSpace<uint8> dev_b_dev_ptrs =
784       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
785                                           /* on_host */ false);
786   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
787                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
788     return errors::Internal("GetrsBatched: failed to copy pointers to device");
789   }
790   if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
791                         host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
792     return errors::Internal("GetrsBatched: failed to copy pointers to device");
793   }
795       cublas_handle, trans, n, nrhs,
796       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
797       dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
798       ldb, host_lapack_info, batch_size));
799   return Status::OK();
800 }
// Defines the CudaSolver::GetrsBatched specialization for one scalar type.
// The getrsBatched routine selected via BLAS_SOLVER_FN is reinterpreted as
// getrs_<type_letter>* before being handed to GetrsBatchedImpl.
#define GETRS_BATCHED_INSTANCE(ScalarT, type_letter)                          \
  template <>                                                                 \
  Status CudaSolver::GetrsBatched(                                            \
      cublasOperation_t trans, int n, int nrhs,                               \
      const ScalarT* const host_a_dev_ptrs[], int lda,                        \
      const int* dev_pivots, const ScalarT* const host_b_dev_ptrs[],          \
      int ldb, int* host_lapack_info, int batch_size) {                       \
    return GetrsBatchedImpl(                                                  \
        reinterpret_cast<getrs_##type_letter*>(                               \
            BLAS_SOLVER_FN(getrsBatched, type_letter)),                       \
        this, context_, cublas_handle_, trans, n, nrhs, host_a_dev_ptrs,      \
        lda, dev_pivots, host_b_dev_ptrs, ldb, host_lapack_info,              \
        batch_size);                                                          \
  }
818 template <typename Scalar, typename SolverFnT>
GetriBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,const int * dev_pivots,const Scalar * const host_a_inv_dev_ptrs[],int ldainv,DeviceLapackInfo * dev_lapack_info,int batch_size)819 static inline Status GetriBatchedImpl(
820     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
821     cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
822     int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
823     int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
824   mutex_lock lock(handle_map_mutex);
825   using CudaScalar = typename CUDAComplexT<Scalar>::type;
826   ScratchSpace<uint8> dev_a_dev_ptrs =
827       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
828                                           /* on_host */ false);
829   ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
830       sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
831   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
832                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
833       !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
834                         host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
835     return errors::Internal("GetriBatched: failed to copy pointers to device");
836   }
838       solver(cublas_handle, n,
839              reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
840              lda, dev_pivots,
841              reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
842              ldainv, dev_lapack_info->mutable_data(), batch_size));
843   return Status::OK();
844 }
// Defines the CudaSolver::GetriBatched specialization for one scalar type.
// The getriBatched routine selected via BLAS_SOLVER_FN is reinterpreted as
// getri_<type_letter>* before being handed to GetriBatchedImpl.
#define GETRI_BATCHED_INSTANCE(ScalarT, type_letter)                          \
  template <>                                                                 \
  Status CudaSolver::GetriBatched(                                            \
      int n, const ScalarT* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const ScalarT* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    return GetriBatchedImpl(reinterpret_cast<getri_##type_letter*>(           \
                                BLAS_SOLVER_FN(getriBatched, type_letter)),   \
                            this, context_, cublas_handle_, n,                \
                            host_a_dev_ptrs, lda, dev_pivots,                 \
                            host_a_inv_dev_ptrs, ldainv, dev_lapack_info,     \
                            batch_size);                                      \
  }
861 template <typename Scalar, typename SolverFnT>
MatInvBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,const Scalar * const host_a_inv_dev_ptrs[],int ldainv,DeviceLapackInfo * dev_lapack_info,int batch_size)862 static inline Status MatInvBatchedImpl(
863     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
864     cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
865     int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
866     DeviceLapackInfo* dev_lapack_info, int batch_size) {
867   mutex_lock lock(handle_map_mutex);
868   using CudaScalar = typename CUDAComplexT<Scalar>::type;
869   ScratchSpace<uint8> dev_a_dev_ptrs =
870       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
871                                           /* on_host */ false);
872   ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
873       sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
874   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
875                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
876       !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
877                         host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
878     return errors::Internal("MatInvBatched: failed to copy pointers to device");
879   }
881       cublas_handle, n,
882       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
883       reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
884       dev_lapack_info->mutable_data(), batch_size));
885   return Status::OK();
886 }
// Defines the CudaSolver::MatInvBatched specialization for one scalar type.
// The matinvBatched routine selected via BLAS_SOLVER_FN is reinterpreted as
// matinv_<type_letter>* before being handed to MatInvBatchedImpl.
#define MATINV_BATCHED_INSTANCE(ScalarT, type_letter)                         \
  template <>                                                                 \
  Status CudaSolver::MatInvBatched(                                           \
      int n, const ScalarT* const host_a_dev_ptrs[], int lda,                 \
      const ScalarT* const host_a_inv_dev_ptrs[], int ldainv,                 \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(                                                 \
        reinterpret_cast<matinv_##type_letter*>(                              \
            BLAS_SOLVER_FN(matinvBatched, type_letter)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda,              \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);            \
  }
903 template <typename Scalar, typename SolverFnT>
TrsmImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasSideMode_t side,cublasFillMode_t uplo,cublasOperation_t trans,cublasDiagType_t diag,int m,int n,const Scalar * alpha,const Scalar * A,int lda,Scalar * B,int ldb)904 static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
905                               cublasSideMode_t side, cublasFillMode_t uplo,
906                               cublasOperation_t trans, cublasDiagType_t diag,
907                               int m, int n,
908                               const Scalar* alpha, /* host or device pointer */
909                               const Scalar* A, int lda, Scalar* B, int ldb) {
910   mutex_lock lock(handle_map_mutex);
911   using CudaScalar = typename CUDAComplexT<Scalar>::type;
912   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
913                                    reinterpret_cast<const CudaScalar*>(alpha),
914                                    reinterpret_cast<const CudaScalar*>(A), lda,
915                                    reinterpret_cast<CudaScalar*>(B), ldb));
916   return Status::OK();
917 }
// Defines the CudaSolver::Trsm<ScalarT> specialization by forwarding to
// TrsmImpl with the trsm routine selected via BLAS_SOLVER_FN for
// `type_letter`.
#define TRSM_INSTANCE(ScalarT, type_letter)                                   \
  template <>                                                                 \
  Status CudaSolver::Trsm<ScalarT>(                                           \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
      cublasDiagType_t diag, int m, int n,                                    \
      const ScalarT* alpha, /* host or device pointer */                      \
      const ScalarT* A, int lda, ScalarT* B, int ldb) {                       \
    return TrsmImpl(BLAS_SOLVER_FN(trsm, type_letter), cublas_handle_,        \
                    side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);    \
  }
932 template <typename Scalar, typename SolverFnT>
TrsvImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasFillMode_t uplo,cublasOperation_t trans,cublasDiagType_t diag,int n,const Scalar * A,int lda,Scalar * x,int incx)933 static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
934                               cublasFillMode_t uplo, cublasOperation_t trans,
935                               cublasDiagType_t diag, int n, const Scalar* A,
936                               int lda, Scalar* x, int incx) {
937   mutex_lock lock(handle_map_mutex);
938   using CudaScalar = typename CUDAComplexT<Scalar>::type;
939   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
940                                    reinterpret_cast<const CudaScalar*>(A), lda,
941                                    reinterpret_cast<CudaScalar*>(x), incx));
942   return Status::OK();
943 }
// Defines the CudaSolver::Trsv<ScalarT> specialization by forwarding to
// TrsvImpl with the trsv routine selected via BLAS_SOLVER_FN for
// `type_letter`.
#define TRSV_INSTANCE(ScalarT, type_letter)                                   \
  template <>                                                                 \
  Status CudaSolver::Trsv<ScalarT>(                                           \
      cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag,  \
      int n, const ScalarT* A, int lda, ScalarT* x, int incx) {               \
    return TrsvImpl(BLAS_SOLVER_FN(trsv, type_letter), cublas_handle_,        \
                    uplo, trans, diag, n, A, lda, x, incx);                   \
  }
956 template <typename Scalar, typename SolverFnT>
TrsmBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,cublasSideMode_t side,cublasFillMode_t uplo,cublasOperation_t trans,cublasDiagType_t diag,int m,int n,const Scalar * alpha,const Scalar * const host_a_dev_ptrs[],int lda,Scalar * host_b_dev_ptrs[],int ldb,int batch_size)957 static inline Status TrsmBatchedImpl(
958     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
959     cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
960     cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
961     const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
962     Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
963   mutex_lock lock(handle_map_mutex);
964   using CudaScalar = typename CUDAComplexT<Scalar>::type;
965   ScratchSpace<uint8> dev_a_dev_ptrs =
966       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
967                                           /* on_host */ false);
968   ScratchSpace<uint8> dev_b_dev_ptrs =
969       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
970                                           /* on_host */ false);
971   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
972                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
973     return errors::Internal("TrsmBatched: failed to copy pointers to device");
974   }
975   if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
976                         host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
977     return errors::Internal("TrsmBatched: failed to copy pointers to device");
978   }
980       solver(cublas_handle, side, uplo, trans, diag, m, n,
981              reinterpret_cast<const CudaScalar*>(alpha),
982              reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
983              lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
984              ldb, batch_size));
985   return Status::OK();
986 }
// Defines the CudaSolver::TrsmBatched specialization for one scalar type by
// forwarding to TrsmBatchedImpl with the trsmBatched routine selected via
// BLAS_SOLVER_FN for `type_letter`.
// NOTE(review): despite the `dev_` prefix, dev_Aarray / dev_Barray are bound
// to TrsmBatchedImpl's host_a_dev_ptrs / host_b_dev_ptrs parameters, which it
// copies host-to-device, so callers are expected to pass host arrays.
#define TRSM_BATCHED_INSTANCE(ScalarT, type_letter)                           \
  template <>                                                                 \
  Status CudaSolver::TrsmBatched(                                             \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
      cublasDiagType_t diag, int m, int n, const ScalarT* alpha,              \
      const ScalarT* const dev_Aarray[], int lda, ScalarT* dev_Barray[],      \
      int ldb, int batch_size) {                                              \
    return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_letter), this,    \
                           context_, cublas_handle_, side, uplo, trans,       \
                           diag, m, n, alpha, dev_Aarray, lda, dev_Barray,    \
                           ldb, batch_size);                                  \
  }
1003 }  // namespace tensorflow
1005 #endif  // GOOGLE_CUDA