1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3    Licensed under the Apache License, Version 2.0 (the "License");
4    you may not use this file except in compliance with the License.
5    You may obtain a copy of the License at
6 
7    http://www.apache.org/licenses/LICENSE-2.0
8 
9    Unless required by applicable law or agreed to in writing, software
10    distributed under the License is distributed on an "AS IS" BASIS,
11    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12    See the License for the specific language governing permissions and
13    limitations under the License.
14    ==============================================================================
15 */
16 #ifdef GOOGLE_CUDA
17 #include "tensorflow/core/kernels/cuda_solvers.h"
18 
19 #include <chrono>
20 #include <complex>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "third_party/gpus/cuda/include/cublas_v2.h"
25 #include "third_party/gpus/cuda/include/cusolverDn.h"
26 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/stream_executor.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
37 
38 // The CUDA cublas_api.h API contains const-correctness errors. Instead of
39 // casting away constness on our data, we reinterpret the CuBLAS functions
40 // as what they were clearly meant to be, and thus we can call the
41 // functions naturally.
42 //
43 // (The error is that input-only arrays are bound to parameter types
44 // "const T**" instead of the correct "const T* const*".)
45 extern "C" {
46 using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
47                                const float* const*, int, const int*, float**,
48                                int, int*, int);
49 using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
50                                const double* const*, int, const int*, double**,
51                                int, int*, int);
52 using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
53                                const float2* const*, int, const int*, float2**,
54                                int, int*, int);
55 using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
56                                const double2* const*, int, const int*,
57                                double2**, int, int*, int);
58 
59 using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
60                                const int*, float**, int, int*, int);
61 using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
62                                const int*, double**, int, int*, int);
63 using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
64                                const int*, float2**, int, int*, int);
65 using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
66                                const int*, double2**, int, int*, int);
67 
68 using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
69                                 float**, int, int*, int);
70 using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
71                                 double**, int, int*, int);
72 using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
73                                 float2**, int, int*, int);
74 using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
75                                 double2**, int, int*, int);
76 }
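// For illustration, a minimal sketch of how these re-declared signatures get
// used further below, assuming the cuBLAS batched entry point
// cublasSgetrsBatched declared in cublas_api.h: the function is reinterpreted
// to the const-correct type before being handed to the wrapper, so the
// input-only pointer array can stay `const float* const*`:
//
//   auto* const_correct_getrs = reinterpret_cast<getrs_S*>(cublasSgetrsBatched);
//
// The GETRS_BATCHED_INSTANCE, GETRI_BATCHED_INSTANCE and
// MATINV_BATCHED_INSTANCE macros below apply exactly this cast.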
77 
78 namespace tensorflow {
79 namespace {
80 
81 using se::cuda::ScopedActivateExecutorContext;
82 
83 inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
84                              const void* src, uint64 bytes) {
85   auto stream = context->op_device_context()->stream();
86   se::DeviceMemoryBase wrapped_dst(dst);
87   return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
88 }
89 
90 // A set of initialized handles to the underlying Cuda libraries used by
91 // CudaSolver. We maintain one such set of handles per unique stream.
92 struct CudaSolverHandles {
93   explicit CudaSolverHandles(cudaStream_t stream) {
94     CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
95         << "Failed to create cuSolverDN instance.";
96     CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
97           CUSOLVER_STATUS_SUCCESS)
98         << "Failed to set cuSolverDN stream.";
99     CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
100         << "Failed to create cuBlas instance.";
101     CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
102         << "Failed to set cuBlas stream.";
103   }
104 
105   ~CudaSolverHandles() {
106     CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
107         << "Failed to destroy cuBlas instance.";
108     CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
109         << "Failed to destroy cuSolverDN instance.";
110   }
111   cublasHandle_t cublas_handle;
112   cusolverDnHandle_t cusolver_dn_handle;
113 };
114 
115 static mutex handle_map_mutex(LINKER_INITIALIZED);
116 
117 using HandleMap =
118     std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;
119 
120 // Returns a singleton map used for storing initialized handles for each unique
121 // cuda stream.
122 HandleMap* GetHandleMapSingleton() {
123   static HandleMap* cm = new HandleMap;
124   return cm;
125 }
126 
127 }  // namespace
128 
129 #define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
130   do {                                                         \
131     auto status = (expr);                                      \
132     if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
133       return errors::Internal(                                 \
134           __FILE__, ":", __LINE__,                             \
135           ": cuSolverDN call failed with status =", status);   \
136     }                                                          \
137   } while (0)
138 
139 #define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
140   do {                                                                   \
141     auto status = (expr);                                                \
142     if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
143       return errors::Internal(__FILE__, ":", __LINE__,                   \
144                               ": cuBlas call failed status = ", status); \
145     }                                                                    \
146   } while (0)
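// For illustration, a minimal usage sketch of the two macros above, assuming a
// hypothetical helper that only sets the stream on an existing handle: each
// library call is wrapped so that a failing status is converted into a
// tensorflow::Status carrying the file and line of the call site.
//
//   Status SetSolverStream(cusolverDnHandle_t handle, cudaStream_t stream) {
//     TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnSetStream(handle, stream));
//     return Status::OK();
//   }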
147 
148 CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
149   mutex_lock lock(handle_map_mutex);
150   const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
151       reinterpret_cast<const cudaStream_t*>(context->op_device_context()
152                                                 ->stream()
153                                                 ->implementation()
154                                                 ->GpuStreamMemberHack()));
155   cuda_stream_ = *cu_stream_ptr;
156   HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
157   auto it = handle_map->find(cuda_stream_);
158   if (it == handle_map->end()) {
159     LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
160     // Previously unseen Cuda stream. Initialize a set of Cuda solver library
161     // handles for it.
162     std::unique_ptr<CudaSolverHandles> new_handles(
163         new CudaSolverHandles(cuda_stream_));
164     it =
165         handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
166             .first;
167   }
168   cusolver_dn_handle_ = it->second->cusolver_dn_handle;
169   cublas_handle_ = it->second->cublas_handle;
170 }
171 
172 CudaSolver::~CudaSolver() {
173   for (auto tensor_ref : scratch_tensor_refs_) {
174     tensor_ref.Unref();
175   }
176 }
177 
178 // static
179 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
180     std::unique_ptr<CudaSolver> solver,
181     const std::vector<DeviceLapackInfo>& dev_lapack_infos,
182     std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
183         info_checker_callback) {
184   CHECK(info_checker_callback != nullptr);
185   std::vector<HostLapackInfo> host_lapack_infos;
186   if (dev_lapack_infos.empty()) {
187     info_checker_callback(Status::OK(), host_lapack_infos);
188     return;
189   }
190 
191   // Launch memcpys to copy info back from the device to the host.
192   for (const auto& dev_lapack_info : dev_lapack_infos) {
193     bool success = true;
194     auto host_copy = dev_lapack_info.CopyToHost(&success);
195     OP_REQUIRES(
196         solver->context(), success,
197         errors::Internal(
198             "Failed to launch copy of dev_lapack_info to host, debug_info = ",
199             dev_lapack_info.debug_info()));
200     host_lapack_infos.push_back(std::move(host_copy));
201   }
202 
203   // This callback checks that all batch items in all calls were processed
204   // successfully and passes status to the info_checker_callback accordingly.
205   auto* stream = solver->context()->op_device_context()->stream();
206   auto wrapped_info_checker_callback =
207       [stream](
208           CudaSolver* solver,
209           std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
210               info_checker_callback,
211           std::vector<HostLapackInfo> host_lapack_infos) {
212         ScopedActivateExecutorContext scoped_activation{stream->parent()};
213         Status status;
214         for (const auto& host_lapack_info : host_lapack_infos) {
215           for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
216             const int info_value = host_lapack_info(i);
217             if (info_value != 0) {
218               status = errors::InvalidArgument(
219                   "Got info = ", info_value, " for batch index ", i,
220                   ", expected info = 0. Debug_info = ",
221                   host_lapack_info.debug_info());
222             }
223           }
224           if (!status.ok()) {
225             break;
226           }
227         }
228         // Delete solver to release temp tensor refs.
229         delete solver;
230 
231         // Delegate further error checking to provided functor.
232         info_checker_callback(status, host_lapack_infos);
233       };
234   // Note: An std::function cannot have unique_ptr arguments (it must be copy
235   // constructible and therefore so must its arguments). Therefore, we release
236   // solver into a raw pointer to be deleted at the end of
237   // wrapped_info_checker_callback.
238   // Release ownership of solver. It will be deleted in the cb callback.
239   auto solver_raw_ptr = solver.release();
240   auto cb =
241       std::bind(wrapped_info_checker_callback, solver_raw_ptr,
242                 std::move(info_checker_callback), std::move(host_lapack_infos));
243 
244   solver_raw_ptr->context()
245       ->device()
246       ->tensorflow_gpu_device_info()
247       ->event_mgr->ThenExecute(stream, std::move(cb));
248 }
249 
250 // static
251 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
252     std::unique_ptr<CudaSolver> solver,
253     const std::vector<DeviceLapackInfo>& dev_lapack_info,
254     AsyncOpKernel::DoneCallback done) {
255   OpKernelContext* context = solver->context();
256   auto wrapped_done = [context, done](
257                           const Status& status,
258                           const std::vector<HostLapackInfo>& /* unused */) {
259     if (done != nullptr) {
260       OP_REQUIRES_OK_ASYNC(context, status, done);
261       done();
262     } else {
263       OP_REQUIRES_OK(context, status);
264     }
265   };
266   CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
267                                       wrapped_done);
268 }
269 
270 // Allocates a temporary tensor. The CudaSolver object maintains a
271 // TensorReference to the underlying Tensor to prevent it from being deallocated
272 // prematurely.
273 Status CudaSolver::allocate_scoped_tensor(DataType type,
274                                           const TensorShape& shape,
275                                           Tensor* out_temp) {
276   const Status status = context_->allocate_temp(type, shape, out_temp);
277   if (status.ok()) {
278     scratch_tensor_refs_.emplace_back(*out_temp);
279   }
280   return status;
281 }
282 
283 Status CudaSolver::forward_input_or_allocate_scoped_tensor(
284     gtl::ArraySlice<int> candidate_input_indices, DataType type,
285     const TensorShape& shape, Tensor* out_temp) {
286   const Status status = context_->forward_input_or_allocate_temp(
287       candidate_input_indices, type, shape, out_temp);
288   if (status.ok()) {
289     scratch_tensor_refs_.emplace_back(*out_temp);
290   }
291   return status;
292 }
293 
294 // Macro that specializes a solver method for all 4 standard
295 // numeric types.
296 #define TF_CALL_LAPACK_TYPES(m) \
297   m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
298 #define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)
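// For illustration, TF_CALL_LAPACK_TYPES(GEAM_INSTANCE) (used further below)
// expands to
//
//   GEAM_INSTANCE(float, S) GEAM_INSTANCE(double, D)
//   GEAM_INSTANCE(std::complex<float>, C) GEAM_INSTANCE(std::complex<double>, Z)
//
// i.e. one explicit specialization per supported scalar type, while the
// NO_COMPLEX variant instantiates only the real-valued specializations.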
299 
300 // Macros to construct cusolverDn method names.
301 #define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
302 #define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
303 #define DN_BUFSIZE_FN(method, type_prefix) \
304   cusolverDn##type_prefix##method##_bufferSize
305 
306 // Macros to construct cublas method names.
307 #define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
308 #define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
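// For illustration, these token-pasting macros simply assemble the library
// function names from a method name and a type prefix; for example,
// DN_SOLVER_FN(potrf, S) yields cusolverDnSpotrf, DN_BUFSIZE_FN(potrf, S)
// yields cusolverDnSpotrf_bufferSize, and BLAS_SOLVER_FN(getrfBatched, Z)
// yields cublasZgetrfBatched.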
309 
310 //=============================================================================
311 // Wrappers of cuSolverDN computational methods begin here.
312 //
313 // WARNING to implementers: The function signatures listed in the online docs
314 // are sometimes inaccurate, e.g., are missing 'const' on pointers
315 // to immutable arguments, while the actual headers have them as expected.
316 // Check the actual declarations in the cusolver_api.h header file.
317 //
318 // NOTE: The cuSolver functions called below appear not to be thread-safe,
319 // so we put a global lock around the calls. Since these functions only put a
320 // kernel on the shared stream, it is not a big performance hit.
321 // TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
322 //=============================================================================
323 
324 template <typename Scalar, typename SolverFnT>
325 static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
326                               cublasOperation_t transa,
327                               cublasOperation_t transb, int m, int n,
328                               const Scalar* alpha, /* host or device pointer */
329                               const Scalar* A, int lda,
330                               const Scalar* beta, /* host or device pointer */
331                               const Scalar* B, int ldb, Scalar* C, int ldc) {
332   mutex_lock lock(handle_map_mutex);
333   using CudaScalar = typename CUDAComplexT<Scalar>::type;
334   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
335                                    reinterpret_cast<const CudaScalar*>(alpha),
336                                    reinterpret_cast<const CudaScalar*>(A), lda,
337                                    reinterpret_cast<const CudaScalar*>(beta),
338                                    reinterpret_cast<const CudaScalar*>(B), ldb,
339                                    reinterpret_cast<CudaScalar*>(C), ldc));
340   return Status::OK();
341 }
342 
343 #define GEAM_INSTANCE(Scalar, type_prefix)                                     \
344   template <>                                                                  \
345   Status CudaSolver::Geam<Scalar>(                                             \
346       cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
347       const Scalar* alpha, /* host or device pointer */                        \
348       const Scalar* A, int lda,                                                \
349       const Scalar* beta, /* host or device pointer */                         \
350       const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
351     return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
352                     transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
353   }
354 
355 TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);
356 
357 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
358 static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
359                                CudaSolver* cuda_solver,
360                                OpKernelContext* context,
361                                cusolverDnHandle_t cusolver_dn_handle,
362                                cublasFillMode_t uplo, int n, Scalar* A, int lda,
363                                int* dev_lapack_info) {
364   mutex_lock lock(handle_map_mutex);
365   /* Get amount of workspace memory required. */
366   int lwork;
367   TF_RETURN_IF_CUSOLVER_ERROR(
368       bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
369   /* Allocate device memory for workspace. */
370   auto dev_workspace =
371       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
372   /* Launch the solver kernel. */
373   TF_RETURN_IF_CUSOLVER_ERROR(solver(
374       cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
375       CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
376   return Status::OK();
377 }
378 
379 #define POTRF_INSTANCE(Scalar, type_prefix)                                  \
380   template <>                                                                \
381   Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
382                                    int lda, int* dev_lapack_info) {          \
383     return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
384                      DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
385                      cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
386   }
387 
388 TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);
389 
390 #if CUDA_VERSION >= 9020
391 template <typename Scalar, typename SolverFnT>
392 static inline Status PotrfBatchedImpl(
393     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
394     cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
395     const Scalar* const host_a_dev_ptrs[], int lda,
396     DeviceLapackInfo* dev_lapack_info, int batch_size) {
397   mutex_lock lock(handle_map_mutex);
398   using CudaScalar = typename CUDAComplexT<Scalar>::type;
399   ScratchSpace<uint8> dev_a_dev_ptrs =
400       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
401                                           /* on_host */ false);
402   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
403                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
404     return errors::Internal("PotrfBatched: failed to copy pointers to device");
405   }
406   TF_RETURN_IF_CUSOLVER_ERROR(
407       solver(cusolver_dn_handle, uplo, n,
408              reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
409              dev_lapack_info->mutable_data(), batch_size));
410   return Status::OK();
411 }
412 
413 #define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                        \
414   template <>                                                              \
415   Status CudaSolver::PotrfBatched(                                         \
416       cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
417       int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
418     return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \
419                             context_, cusolver_dn_handle_, uplo, n,        \
420                             host_a_dev_ptrs, lda, dev_lapack_info,         \
421                             batch_size);                                   \
422   }
423 
424 TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
425 #endif  // CUDA_VERSION >= 9020
426 
427 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
428 static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
429                                CudaSolver* cuda_solver,
430                                OpKernelContext* context,
431                                cusolverDnHandle_t cusolver_dn_handle, int m,
432                                int n, Scalar* A, int lda, int* dev_pivots,
433                                int* dev_lapack_info) {
434   mutex_lock lock(handle_map_mutex);
435   /* Get amount of workspace memory required. */
436   int lwork;
437   TF_RETURN_IF_CUSOLVER_ERROR(
438       bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
439   /* Allocate device memory for workspace. */
440   auto dev_workspace =
441       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
442   /* Launch the solver kernel. */
443   TF_RETURN_IF_CUSOLVER_ERROR(solver(
444       cusolver_dn_handle, m, n, CUDAComplex(A), lda,
445       CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
446   return Status::OK();
447 }
448 
449 #define GETRF_INSTANCE(Scalar, type_prefix)                                 \
450   template <>                                                               \
451   Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
452                                    int* dev_pivots, int* dev_lapack_info) { \
453     return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
454                      DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
455                      cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
456                      dev_lapack_info);                                      \
457   }
458 
459 TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);
460 
461 template <typename Scalar, typename SolverFnT>
462 static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
463                                cusolverDnHandle_t cusolver_dn_handle,
464                                cublasOperation_t trans, int n, int nrhs,
465                                const Scalar* A, int lda, const int* pivots,
466                                Scalar* B, int ldb, int* dev_lapack_info) {
467   mutex_lock lock(handle_map_mutex);
468   /* Launch the solver kernel. */
469   TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
470                                      CUDAComplex(A), lda, pivots,
471                                      CUDAComplex(B), ldb, dev_lapack_info));
472   return Status::OK();
473 }
474 
475 #define GETRS_INSTANCE(Scalar, type_prefix)                                  \
476   template <>                                                                \
477   Status CudaSolver::Getrs<Scalar>(                                          \
478       cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
479       const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
480     return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
481                      cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
482                      ldb, dev_lapack_info);                                  \
483   }
484 
485 TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);
486 
487 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
488 static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
489                                CudaSolver* cuda_solver,
490                                OpKernelContext* context,
491                                cusolverDnHandle_t cusolver_dn_handle, int m,
492                                int n, Scalar* A, int lda, Scalar* tau,
493                                int* dev_lapack_info) {
494   mutex_lock lock(handle_map_mutex);
495   /* Get amount of workspace memory required. */
496   int lwork;
497   TF_RETURN_IF_CUSOLVER_ERROR(
498       bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
499   /* Allocate device memory for workspace. */
500   auto dev_workspace =
501       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
502   /* Launch the solver kernel. */
503   TF_RETURN_IF_CUSOLVER_ERROR(solver(
504       cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
505       CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
506   return Status::OK();
507 }
508 
509 #define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
510   template <>                                                                  \
511   Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
512                                    Scalar* tau, int* dev_lapack_info) {        \
513     return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
514                      DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
515                      cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
516   }
517 
518 TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);
519 
520 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
521 static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
522                                CudaSolver* cuda_solver,
523                                OpKernelContext* context,
524                                cusolverDnHandle_t cusolver_dn_handle,
525                                cublasSideMode_t side, cublasOperation_t trans,
526                                int m, int n, int k, const Scalar* dev_a,
527                                int lda, const Scalar* dev_tau, Scalar* dev_c,
528                                int ldc, int* dev_lapack_info) {
529   mutex_lock lock(handle_map_mutex);
530   /* Get amount of workspace memory required. */
531   int lwork;
532   TF_RETURN_IF_CUSOLVER_ERROR(
533       bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
534               CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
535   /* Allocate device memory for workspace. */
536   auto dev_workspace =
537       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
538   /* Launch the solver kernel. */
539   TF_RETURN_IF_CUSOLVER_ERROR(solver(
540       cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
541       CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
542       CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
543   return Status::OK();
544 }
545 
546 // Unfortunately the LAPACK function name differs for the real and complex case
547 // (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
548 // one separately.
549 #define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
550   template <>                                                                 \
551   Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
552                            int m, int n, int k, const Scalar* dev_a, int lda, \
553                            const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
554                            int* dev_lapack_info) {                            \
555     return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
556                      DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
557                      context_, cusolver_dn_handle_, side, trans, m, n, k,     \
558                      dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
559   }
560 
561 UNMQR_INSTANCE(float, or, S);
562 UNMQR_INSTANCE(double, or, D);
563 UNMQR_INSTANCE(complex64, un, C);
564 UNMQR_INSTANCE(complex128, un, Z);
565 
566 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
567 static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
568                                CudaSolver* cuda_solver,
569                                OpKernelContext* context,
570                                cusolverDnHandle_t cusolver_dn_handle, int m,
571                                int n, int k, Scalar* dev_a, int lda,
572                                const Scalar* dev_tau, int* dev_lapack_info) {
573   mutex_lock lock(handle_map_mutex);
574   /* Get amount of workspace memory required. */
575   int lwork;
576   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
577                                       CUDAComplex(dev_a), lda,
578                                       CUDAComplex(dev_tau), &lwork));
579   /* Allocate device memory for workspace. */
580   auto dev_workspace =
581       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
582   /* Launch the solver kernel. */
583   TF_RETURN_IF_CUSOLVER_ERROR(
584       solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
585              CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
586              lwork, dev_lapack_info));
587   return Status::OK();
588 }
589 
590 #define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
591   template <>                                                               \
592   Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
593                            const Scalar* dev_tau, int* dev_lapack_info) {   \
594     return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
595                      DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
596                      context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
597                      dev_tau, dev_lapack_info);                             \
598   }
599 
600 UNGQR_INSTANCE(float, or, S);
601 UNGQR_INSTANCE(double, or, D);
602 UNGQR_INSTANCE(complex64, un, C);
603 UNGQR_INSTANCE(complex128, un, Z);
604 
605 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
606 static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
607                                CudaSolver* cuda_solver,
608                                OpKernelContext* context,
609                                cusolverDnHandle_t cusolver_dn_handle,
610                                cusolverEigMode_t jobz, cublasFillMode_t uplo,
611                                int n, Scalar* dev_A, int lda,
612                                typename Eigen::NumTraits<Scalar>::Real* dev_W,
613                                int* dev_lapack_info) {
614   mutex_lock lock(handle_map_mutex);
615   /* Get amount of workspace memory required. */
616   int lwork;
617   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
618                                       CUDAComplex(dev_A), lda,
619                                       CUDAComplex(dev_W), &lwork));
620   /* Allocate device memory for workspace. */
621   auto dev_workspace =
622       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
623   /* Launch the solver kernel. */
624   TF_RETURN_IF_CUSOLVER_ERROR(
625       solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
626              CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
627              lwork, dev_lapack_info));
628   return Status::OK();
629 }
630 
631 #define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
632   template <>                                                                  \
633   Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
634                            int n, Scalar* dev_A, int lda,                      \
635                            typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
636                            int* dev_lapack_info) {                             \
637     return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
638                      DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
639                      context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
640                      dev_W, dev_lapack_info);                                  \
641   }
642 
643 HEEVD_INSTANCE(float, sy, S);
644 HEEVD_INSTANCE(double, sy, D);
645 HEEVD_INSTANCE(complex64, he, C);
646 HEEVD_INSTANCE(complex128, he, Z);
647 
648 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
649 static inline Status GesvdImpl(
650     BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
651     OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
652     signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
653     Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
654   mutex_lock lock(handle_map_mutex);
655   /* Get amount of workspace memory required. */
656   int lwork;
657   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
658   /* Allocate device memory for workspace. */
659   auto dev_workspace =
660       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
661   TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
662                                      CUDAComplex(A), lda, S, CUDAComplex(U),
663                                      ldu, CUDAComplex(VT), ldvt,
664                                      CUDAComplex(dev_workspace.mutable_data()),
665                                      lwork, nullptr, dev_lapack_info));
666   return Status::OK();
667 }
668 
669 #define GESVD_INSTANCE(Scalar, type_prefix)                              \
670   template <>                                                            \
671   Status CudaSolver::Gesvd<Scalar>(                                      \
672       signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
673       int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
674       int ldvt, int* dev_lapack_info) {                                  \
675     return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
676                      DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
677                      cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
678                      dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
679   }
680 
681 TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);
682 
683 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
684 static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
685                                        CudaSolver* cuda_solver,
686                                        OpKernelContext* context,
687                                        cusolverDnHandle_t cusolver_dn_handle,
688                                        cusolverEigMode_t jobz, int m, int n,
689                                        Scalar* A, int lda, Scalar* S, Scalar* U,
690                                        int ldu, Scalar* V, int ldv,
691                                        int* dev_lapack_info, int batch_size) {
692   mutex_lock lock(handle_map_mutex);
693   /* Get amount of workspace memory required. */
694   int lwork;
695   /* Default parameters for gesvdj and gesvdjBatched. */
696   gesvdjInfo_t svdj_info;
697   TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
698   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
699       cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
700       ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
701   /* Allocate device memory for workspace. */
702   auto dev_workspace =
703       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
704   TF_RETURN_IF_CUSOLVER_ERROR(solver(
705       cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
706       ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
707       lwork, dev_lapack_info, svdj_info, batch_size));
708   TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
709   return Status::OK();
710 }
711 
712 #define GESVDJBATCHED_INSTANCE(Scalar, type_prefix)                            \
713   template <>                                                                  \
714   Status CudaSolver::GesvdjBatched<Scalar>(                                    \
715       cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda,            \
716       Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv,           \
717       int* dev_lapack_info, int batch_size) {                                  \
718     return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix),        \
719                              DN_SOLVER_FN(gesvdjBatched, type_prefix), this,   \
720                              context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
721                              lda, dev_S, dev_U, ldu, dev_V, ldv,               \
722                              dev_lapack_info, batch_size);                     \
723   }
724 
725 TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);
726 
727 //=============================================================================
728 // Wrappers of cuBlas computational methods begin here.
729 //
730 // WARNING to implementers: The function signatures listed in the online docs
731 // are sometimes inaccurate, e.g., are missing 'const' on pointers
732 // to immutable arguments, while the actual headers have them as expected.
733 // Check the actual declarations in the cublas_api.h header file.
734 //=============================================================================
735 template <typename Scalar, typename SolverFnT>
736 static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
737                                       OpKernelContext* context,
738                                       cublasHandle_t cublas_handle, int n,
739                                       const Scalar* const host_a_dev_ptrs[],
740                                       int lda, int* dev_pivots,
741                                       DeviceLapackInfo* dev_lapack_info,
742                                       int batch_size) {
743   mutex_lock lock(handle_map_mutex);
744   using CudaScalar = typename CUDAComplexT<Scalar>::type;
745   ScratchSpace<uint8> dev_a_dev_ptrs =
746       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
747                                           /* on_host */ false);
748   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
749                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
750     return errors::Internal("GetrfBatched: failed to copy pointers to device");
751   }
752   TF_RETURN_IF_CUBLAS_ERROR(
753       solver(cublas_handle, n,
754              reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
755              dev_pivots, dev_lapack_info->mutable_data(), batch_size));
756   return Status::OK();
757 }
758 
759 #define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
760   template <>                                                                  \
761   Status CudaSolver::GetrfBatched(                                             \
762       int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
763       DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
764     return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
765                             context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
766                             dev_pivots, dev_lapack_info, batch_size);          \
767   }
768 
769 TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);
770 
771 template <typename Scalar, typename SolverFnT>
772 static inline Status GetrsBatchedImpl(
773     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
774     cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
775     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
776     const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
777     int batch_size) {
778   mutex_lock lock(handle_map_mutex);
779   using CudaScalar = typename CUDAComplexT<Scalar>::type;
780   ScratchSpace<uint8> dev_a_dev_ptrs =
781       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
782                                           /* on_host */ false);
783   ScratchSpace<uint8> dev_b_dev_ptrs =
784       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
785                                           /* on_host */ false);
786   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
787                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
788     return errors::Internal("GetrsBatched: failed to copy pointers to device");
789   }
790   if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
791                         host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
792     return errors::Internal("GetrsBatched: failed to copy pointers to device");
793   }
794   TF_RETURN_IF_CUBLAS_ERROR(solver(
795       cublas_handle, trans, n, nrhs,
796       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
797       dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
798       ldb, host_lapack_info, batch_size));
799   return Status::OK();
800 }
801 
802 #define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
803   template <>                                                                  \
804   Status CudaSolver::GetrsBatched(                                             \
805       cublasOperation_t trans, int n, int nrhs,                                \
806       const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
807       const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,   \
808       int batch_size) {                                                        \
809     return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
810                                 BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
811                             this, context_, cublas_handle_, trans, n, nrhs,    \
812                             host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
813                             ldb, host_lapack_info, batch_size);                \
814   }
815 
816 TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
817 
818 template <typename Scalar, typename SolverFnT>
819 static inline Status GetriBatchedImpl(
820     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
821     cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
822     int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
823     int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
824   mutex_lock lock(handle_map_mutex);
825   using CudaScalar = typename CUDAComplexT<Scalar>::type;
826   ScratchSpace<uint8> dev_a_dev_ptrs =
827       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
828                                           /* on_host */ false);
829   ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
830       sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
831   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
832                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
833       !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
834                         host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
835     return errors::Internal("GetriBatched: failed to copy pointers to device");
836   }
837   TF_RETURN_IF_CUBLAS_ERROR(
838       solver(cublas_handle, n,
839              reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
840              lda, dev_pivots,
841              reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
842              ldainv, dev_lapack_info->mutable_data(), batch_size));
843   return Status::OK();
844 }
845 
846 #define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
847   template <>                                                                \
848   Status CudaSolver::GetriBatched(                                           \
849       int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
850       const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
851       int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
852     return GetriBatchedImpl(                                                 \
853         reinterpret_cast<getri_##type_prefix*>(                              \
854             BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
855         this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
856         host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
857   }
858 
859 TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);
860 
861 template <typename Scalar, typename SolverFnT>
862 static inline Status MatInvBatchedImpl(
863     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
864     cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
865     int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
866     DeviceLapackInfo* dev_lapack_info, int batch_size) {
867   mutex_lock lock(handle_map_mutex);
868   using CudaScalar = typename CUDAComplexT<Scalar>::type;
869   ScratchSpace<uint8> dev_a_dev_ptrs =
870       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
871                                           /* on_host */ false);
872   ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
873       sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
874   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
875                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
876       !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
877                         host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
878     return errors::Internal("MatInvBatched: failed to copy pointers to device");
879   }
880   TF_RETURN_IF_CUBLAS_ERROR(solver(
881       cublas_handle, n,
882       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
883       reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
884       dev_lapack_info->mutable_data(), batch_size));
885   return Status::OK();
886 }
887 
888 #define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
889   template <>                                                                 \
890   Status CudaSolver::MatInvBatched(                                           \
891       int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
892       const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
893       DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
894     return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
895                                  BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
896                              this, context_, cublas_handle_, n,               \
897                              host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
898                              ldainv, dev_lapack_info, batch_size);            \
899   }
900 
901 TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);
902 
903 template <typename Scalar, typename SolverFnT>
904 static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
905                               cublasSideMode_t side, cublasFillMode_t uplo,
906                               cublasOperation_t trans, cublasDiagType_t diag,
907                               int m, int n,
908                               const Scalar* alpha, /* host or device pointer */
909                               const Scalar* A, int lda, Scalar* B, int ldb) {
910   mutex_lock lock(handle_map_mutex);
911   using CudaScalar = typename CUDAComplexT<Scalar>::type;
912   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
913                                    reinterpret_cast<const CudaScalar*>(alpha),
914                                    reinterpret_cast<const CudaScalar*>(A), lda,
915                                    reinterpret_cast<CudaScalar*>(B), ldb));
916   return Status::OK();
917 }
918 
919 #define TRSM_INSTANCE(Scalar, type_prefix)                                   \
920   template <>                                                                \
921   Status CudaSolver::Trsm<Scalar>(                                           \
922       cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
923       cublasDiagType_t diag, int m, int n,                                   \
924       const Scalar* alpha, /* host or device pointer */                      \
925       const Scalar* A, int lda, Scalar* B, int ldb) {                        \
926     return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
927                     uplo, trans, diag, m, n, alpha, A, lda, B, ldb);         \
928   }
929 
930 TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
931 
932 template <typename Scalar, typename SolverFnT>
933 static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
934                               cublasFillMode_t uplo, cublasOperation_t trans,
935                               cublasDiagType_t diag, int n, const Scalar* A,
936                               int lda, Scalar* x, int incx) {
937   mutex_lock lock(handle_map_mutex);
938   using CudaScalar = typename CUDAComplexT<Scalar>::type;
939   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
940                                    reinterpret_cast<const CudaScalar*>(A), lda,
941                                    reinterpret_cast<CudaScalar*>(x), incx));
942   return Status::OK();
943 }
944 
945 #define TRSV_INSTANCE(Scalar, type_prefix)                                   \
946   template <>                                                                \
947   Status CudaSolver::Trsv<Scalar>(                                           \
948       cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
949       int n, const Scalar* A, int lda, Scalar* x, int incx) {                \
950     return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
951                     trans, diag, n, A, lda, x, incx);                        \
952   }
953 
954 TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);
955 
956 template <typename Scalar, typename SolverFnT>
957 static inline Status TrsmBatchedImpl(
958     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
959     cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
960     cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
961     const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
962     Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
963   mutex_lock lock(handle_map_mutex);
964   using CudaScalar = typename CUDAComplexT<Scalar>::type;
965   ScratchSpace<uint8> dev_a_dev_ptrs =
966       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
967                                           /* on_host */ false);
968   ScratchSpace<uint8> dev_b_dev_ptrs =
969       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
970                                           /* on_host */ false);
971   if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
972                         host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
973     return errors::Internal("TrsmBatched: failed to copy pointers to device");
974   }
975   if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
976                         host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
977     return errors::Internal("TrsmBatched: failed to copy pointers to device");
978   }
979   TF_RETURN_IF_CUBLAS_ERROR(
980       solver(cublas_handle, side, uplo, trans, diag, m, n,
981              reinterpret_cast<const CudaScalar*>(alpha),
982              reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
983              lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
984              ldb, batch_size));
985   return Status::OK();
986 }
987 
988 #define TRSM_BATCHED_INSTANCE(Scalar, type_prefix)                            \
989   template <>                                                                 \
990   Status CudaSolver::TrsmBatched(                                             \
991       cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
992       cublasDiagType_t diag, int m, int n, const Scalar* alpha,               \
993       const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[],        \
994       int ldb, int batch_size) {                                              \
995     return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this,    \
996                            context_, cublas_handle_, side, uplo, trans, diag, \
997                            m, n, alpha, dev_Aarray, lda, dev_Barray, ldb,     \
998                            batch_size);                                       \
999   }
1000 
1001 TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);
1002 
1003 }  // namespace tensorflow
1004 
1005 #endif  // GOOGLE_CUDA
1006