/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
   ==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h"

#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

// The CUDA cublas_api.h API contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the CuBLAS functions as
// having the signatures they were clearly meant to have, so that we can call
// them naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                                double2**, int, int*, int);

using trsm_S = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const float*,
                              const float* const*, int, float* const*, int,
                              int);
using trsm_D = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const double*,
                              const double* const*, int, double* const*, int,
                              int);
using trsm_C = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const float2*,
                              const float2* const*, int, float2* const*, int,
                              int);
using trsm_Z = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const double2*,
                              const double2* const*, int, double2* const*, int,
                              int);
}
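
// These function types are matched below by reinterpret_cast'ing the actual
// cuBLAS entry points; e.g. the GetrsBatched wrapper casts
// cublasSgetrsBatched to getrs_S* before calling it.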

namespace tensorflow {
namespace {

using se::cuda::ScopedActivateExecutorContext;

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// CudaSolver. We maintain one such set of handles per unique stream.
struct CudaSolverHandles {
  explicit CudaSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~CudaSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)

CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<CudaSolverHandles> new_handles(
        new CudaSolverHandles(cuda_stream_));
    it =
        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
            .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

CudaSolver::~CudaSolver() {
  for (const auto& tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          CudaSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: An std::function cannot have unique_ptr arguments (it must be copy
  // constructible and therefore so must its arguments). Therefore, we release
  // solver into a raw pointer to be deleted at the end of
  // wrapped_info_checker_callback.
  // Release ownership of solver. It will be deleted in the cb callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_gpu_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}
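
// A typical async op kernel drives these wrappers roughly as follows. This is
// a sketch with illustrative variable names, not code from this file, and it
// assumes the GetDeviceLapackInfo helper declared in cuda_solvers.h:
//
//   std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//   std::vector<DeviceLapackInfo> dev_info;
//   dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
//   OP_REQUIRES_OK_ASYNC(
//       context,
//       solver->Getrf(m, n, dev_a, lda, dev_pivots,
//                     dev_info.back().mutable_data()),
//       done);
//   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                   dev_info, done);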

// Allocates a temporary tensor. The CudaSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status CudaSolver::allocate_scoped_tensor(DataType type,
                                          const TensorShape& shape,
                                          Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status CudaSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)
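
// For example, TF_CALL_LAPACK_TYPES(GETRF_INSTANCE) below expands
// GETRF_INSTANCE once per (type, LAPACK prefix) pair: (float, S), (double, D),
// (std::complex<float>, C) and (std::complex<double>, Z).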

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize
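
// E.g. DN_SOLVER_FN(potrf, S) expands to cusolverDnSpotrf, and
// DN_BUFSIZE_FN(potrf, S) to cusolverDnSpotrf_bufferSize.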

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
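
// E.g. BLAS_SOLVER_FN(getrfBatched, D) expands to cublasDgetrfBatched.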

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be thread-safe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================

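// Geam wraps cublas<t>geam, which computes the matrix-matrix addition /
// transposition C = alpha * op(A) + beta * op(B).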
template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status CudaSolver::Geam<Scalar>(                                             \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
                                   int lda, int* dev_lapack_info) {          \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);
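
// Potrf above illustrates the three-step pattern used by most cuSolverDN
// wrappers in this file: query the required workspace size via the
// corresponding *_bufferSize function, allocate device scratch space through
// the CudaSolver, then launch the solver kernel with that workspace and a
// device pointer for the LAPACK info output.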

#if CUDA_VERSION >= 9020
template <typename Scalar, typename SolverFnT>
static inline Status PotrfBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
    const Scalar* const host_a_dev_ptrs[], int lda,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("PotrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, uplo, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                        \
  template <>                                                              \
  Status CudaSolver::PotrfBatched(                                         \
      cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \
                            context_, cusolver_dn_handle_, uplo, n,        \
                            host_a_dev_ptrs, lda, dev_lapack_info,         \
                            batch_size);                                   \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
#endif  // CUDA_VERSION >= 9020

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return Status::OK();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                   int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
                     dev_lapack_info);                                      \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return Status::OK();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Getrs<Scalar>(                                          \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
                                   Scalar* tau, int* dev_lapack_info) {        \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

// Unfortunately the LAPACK function name differs for the real and complex
// cases (the complex ones are prefixed with "UN" for "unitary"), so we
// instantiate each one separately.
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
  template <>                                                                 \
  Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                           int m, int n, int k, const Scalar* dev_a, int lda, \
                           const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                           int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,     \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);
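
// E.g. UNMQR_INSTANCE(float, or, S) instantiates Unmqr<float> on top of
// cusolverDnSormqr, while UNMQR_INSTANCE(complex64, un, C) uses
// cusolverDnCunmqr.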

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
                           const Scalar* dev_tau, int* dev_lapack_info) {   \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
                           int n, Scalar* dev_A, int lda,                      \
                           typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
                           int* dev_lapack_info) {                             \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);
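
// E.g. HEEVD_INSTANCE(float, sy, S) dispatches to cusolverDnSsyevd (symmetric
// eigendecomposition), while HEEVD_INSTANCE(complex64, he, C) dispatches to
// cusolverDnCheevd (Hermitian eigendecomposition).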

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(
    BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
    OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
    signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
    Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return Status::OK();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status CudaSolver::Gesvd<Scalar>(                                      \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
                                       CudaSolver* cuda_solver,
                                       OpKernelContext* context,
                                       cusolverDnHandle_t cusolver_dn_handle,
                                       cusolverEigMode_t jobz, int m, int n,
                                       Scalar* A, int lda, Scalar* S, Scalar* U,
                                       int ldu, Scalar* V, int ldv,
                                       int* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  /* Default parameters for gesvdj and gesvdjBatched. */
  gesvdjInfo_t svdj_info;
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
      lwork, dev_lapack_info, svdj_info, batch_size));
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
  return Status::OK();
}

#define GESVDJBATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GesvdjBatched<Scalar>(                                    \
      cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda,            \
      Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv,           \
      int* dev_lapack_info, int batch_size) {                                  \
    return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix),        \
                             DN_SOLVER_FN(gesvdjBatched, type_prefix), this,   \
                             context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
                             lda, dev_S, dev_U, ldu, dev_V, ldv,               \
                             dev_lapack_info, batch_size);                     \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
//=============================================================================
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrfBatched(                                             \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);
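
// The batched cuBlas wrappers take a host array of device pointers, one
// pointer per batch element; each wrapper copies that array to device scratch
// memory before invoking the batched kernel. A sketch of a call site with
// illustrative names (not code from this file), assuming batch_size square
// n x n matrices stored contiguously on the device:
//
//   std::vector<const Scalar*> host_ptrs(batch_size);
//   for (int i = 0; i < batch_size; ++i) {
//     host_ptrs[i] = dev_matrix_base + i * n * n;
//   }
//   TF_RETURN_IF_ERROR(solver->GetrfBatched(n, host_ptrs.data(), n,
//                                           dev_pivots, &dev_info,
//                                           batch_size));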

template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
    int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
      ldb, host_lapack_info, batch_size));
  return Status::OK();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrsBatched(                                             \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,   \
      int batch_size) {                                                        \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, host_lapack_info, batch_size);                \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status CudaSolver::GetriBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
      dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status CudaSolver::MatInvBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

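// Trsm wraps cublas<t>trsm, which solves the triangular system
// op(A) * X = alpha * B (or X * op(A) = alpha * B for CUBLAS_SIDE_RIGHT),
// overwriting B with the solution X. Trsv below is the single
// right-hand-side vector variant.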
template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasSideMode_t side, cublasFillMode_t uplo,
                              cublasOperation_t trans, cublasDiagType_t diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(B), ldb));
  return Status::OK();
}

#define TRSM_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status CudaSolver::Trsm<Scalar>(                                           \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
      cublasDiagType_t diag, int m, int n,                                   \
      const Scalar* alpha, /* host or device pointer */                      \
      const Scalar* A, int lda, Scalar* B, int ldb) {                        \
    return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
                    uplo, trans, diag, m, n, alpha, A, lda, B, ldb);         \
  }

TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasFillMode_t uplo, cublasOperation_t trans,
                              cublasDiagType_t diag, int n, const Scalar* A,
                              int lda, Scalar* x, int incx) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(x), incx));
  return Status::OK();
}

#define TRSV_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status CudaSolver::Trsv<Scalar>(                                           \
      cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
      int n, const Scalar* A, int lda, Scalar* x, int incx) {                \
    return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
                    trans, diag, n, A, lda, x, incx);                        \
  }

TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
    cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
    const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
    Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, side, uplo, trans, diag, m, n,
             reinterpret_cast<const CudaScalar*>(alpha),
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
             ldb, batch_size));
  return Status::OK();
}

#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                 \
  Status CudaSolver::TrsmBatched(                                             \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
      cublasDiagType_t diag, int m, int n, const Scalar* alpha,               \
      const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[],        \
      int ldb, int batch_size) {                                              \
    return TrsmBatchedImpl(reinterpret_cast<trsm_##type_prefix*>(             \
                               BLAS_SOLVER_FN(trsmBatched, type_prefix)),     \
                           this, context_, cublas_handle_, side, uplo, trans, \
                           diag, m, n, alpha, dev_Aarray, lda, dev_Barray,    \
                           ldb, batch_size);                                  \
  }

TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA