/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
   ==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"

#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

// The CUDA cublas_api.h API contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the cuBLAS functions
// with the signatures they were clearly meant to have, so that we can call
// them naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                                double2**, int, int*, int);
}
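
// These aliases are used by the GETRS/GETRI/MATINV_BATCHED_INSTANCE macros
// below, which reinterpret_cast the corresponding cuBLAS entry points to the
// const-correct signatures, e.g. reinterpret_cast<getrs_S*>(cublasSgetrsBatched)
// for the float specialization.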

namespace tensorflow {
namespace {

using se::cuda::ScopedActivateExecutorContext;

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// CudaSolver. We maintain one such set of handles per unique stream.
struct CudaSolverHandles {
  explicit CudaSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~CudaSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
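// Entries are created lazily by the CudaSolver constructor and never removed,
// so each stream's handles persist for the lifetime of the process.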
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)

CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<CudaSolverHandles> new_handles(
        new CudaSolverHandles(cuda_stream_));
    it =
        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
            .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

CudaSolver::~CudaSolver() {
  for (auto tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          CudaSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: A std::function cannot have unique_ptr arguments (it must be copy
  // constructible, and therefore so must its arguments). We therefore release
  // ownership of solver into a raw pointer, which is deleted at the end of
  // wrapped_info_checker_callback above.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_gpu_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}

// Allocates a temporary tensor. The CudaSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status CudaSolver::allocate_scoped_tensor(DataType type,
                                          const TensorShape& shape,
                                          Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status CudaSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize
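// For example, DN_SOLVER_FN(potrf, S) expands to cusolverDnSpotrf and
// DN_BUFSIZE_FN(potrf, S) expands to cusolverDnSpotrf_bufferSize.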

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
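// For example, BLAS_SOLVER_FN(getrfBatched, S) expands to cublasSgetrfBatched.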

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be threadsafe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
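//
// Most of the wrappers below follow the same pattern: query the required
// workspace size, allocate device scratch space through the CudaSolver
// object, then launch the solver kernel, which reports per-matrix status
// through dev_lapack_info.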
//=============================================================================

template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status CudaSolver::Geam<Scalar>(                                             \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
                                   int lda, int* dev_lapack_info) {          \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return Status::OK();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                   int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
                     dev_lapack_info);                                      \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return Status::OK();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Getrs<Scalar>(                                          \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
                                   Scalar* tau, int* dev_lapack_info) {        \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
  template <>                                                                 \
  Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                           int m, int n, int k, const Scalar* dev_a, int lda, \
                           const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                           int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,     \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
                           const Scalar* dev_tau, int* dev_lapack_info) {   \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
                           int n, Scalar* dev_A, int lda,                      \
                           typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
                           int* dev_lapack_info) {                             \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(
    BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
    OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
    signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
    Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return Status::OK();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status CudaSolver::Gesvd<Scalar>(                                      \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
                                       CudaSolver* cuda_solver,
                                       OpKernelContext* context,
                                       cusolverDnHandle_t cusolver_dn_handle,
                                       cusolverEigMode_t jobz, int m, int n,
                                       Scalar* A, int lda, Scalar* S, Scalar* U,
                                       int ldu, Scalar* V, int ldv,
                                       int* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  /* Default parameters for gesvdj and gesvdjBatched. */
  gesvdjInfo_t svdj_info;
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
      lwork, dev_lapack_info, svdj_info, batch_size));
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
  return Status::OK();
}

#define GESVDJBATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GesvdjBatched<Scalar>(                                    \
      cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda,            \
      Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv,           \
      int* dev_lapack_info, int batch_size) {                                  \
    return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix),        \
                             DN_SOLVER_FN(gesvdjBatched, type_prefix), this,   \
                             context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
                             lda, dev_S, dev_U, ldu, dev_V, ldv,               \
                             dev_lapack_info, batch_size);                     \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
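//
// Each batched wrapper below receives a host-side array of device pointers,
// copies it into device scratch space via CopyHostToDevice, and then invokes
// the batched cuBLAS routine on the device-side pointer array.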
//=============================================================================
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrfBatched(                                             \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
    int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
      ldb, host_lapack_info, batch_size));
  return Status::OK();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrsBatched(                                             \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,   \
      int batch_size) {                                                        \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, host_lapack_info, batch_size);                \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status CudaSolver::GetriBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
      dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status CudaSolver::MatInvBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA