/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"

#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
// The CUDA cublas_api.h header contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the CuBLAS functions as
// what they were clearly meant to be, so that we can call them naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                                double2**, int, int*, int);
}
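// These aliases are applied below (see the GETRS/GETRI/MATINV_BATCHED_INSTANCE
// macros) with casts along the lines of
//   reinterpret_cast<getrs_S*>(cublasSgetrsBatched)
// so the wrappers can accept "const T* const*" arrays without const_cast.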

namespace tensorflow {
namespace {

using se::cuda::ScopedActivateExecutorContext;

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// CudaSolver. We maintain one such set of handles per unique stream.
struct CudaSolverHandles {
  explicit CudaSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~CudaSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)
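// Both macros are intended for use inside functions returning Status;
// an illustrative call (names are placeholders) looks like
//   TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnSetStream(handle, stream));
// which converts a non-success library status into an errors::Internal return.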

CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<CudaSolverHandles> new_handles(
        new CudaSolverHandles(cuda_stream_));
    it =
        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
            .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

CudaSolver::~CudaSolver() {
  for (auto tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          CudaSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: An std::function cannot have unique_ptr arguments (it must be copy
  // constructible and therefore so must its arguments). Therefore, we release
  // solver into a raw pointer to be deleted at the end of
  // wrapped_info_checker_callback.
  // Release ownership of solver. It will be deleted in the cb callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_gpu_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}

// Allocates a temporary tensor. The CudaSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status CudaSolver::allocate_scoped_tensor(DataType type,
                                          const TensorShape& shape,
                                          Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status CudaSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)
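// For example, TF_CALL_LAPACK_TYPES(POTRF_INSTANCE) expands to
//   POTRF_INSTANCE(float, S) POTRF_INSTANCE(double, D)
//   POTRF_INSTANCE(std::complex<float>, C) POTRF_INSTANCE(std::complex<double>, Z)
// i.e. one explicit specialization per supported scalar type.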

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize
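// For example, DN_SOLVER_FN(potrf, S) pastes together cusolverDnSpotrf, and
// DN_BUFSIZE_FN(potrf, S) pastes together cusolverDnSpotrf_bufferSize.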

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
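// Likewise, BLAS_SOLVER_FN(getrfBatched, S) pastes together cublasSgetrfBatched.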

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be thread-safe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================

template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status CudaSolver::Geam<Scalar>(                                             \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
                                   int lda, int* dev_lapack_info) {          \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#if CUDA_VERSION >= 9020
template <typename Scalar, typename SolverFnT>
static inline Status PotrfBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
    const Scalar* const host_a_dev_ptrs[], int lda,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("PotrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, uplo, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                        \
  template <>                                                              \
  Status CudaSolver::PotrfBatched(                                         \
      cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \
                            context_, cusolver_dn_handle_, uplo, n,        \
                            host_a_dev_ptrs, lda, dev_lapack_info,         \
                            batch_size);                                   \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
#endif  // CUDA_VERSION >= 9020

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return Status::OK();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                   int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
                     dev_lapack_info);                                      \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return Status::OK();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Getrs<Scalar>(                                          \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
                                   Scalar* tau, int* dev_lapack_info) {        \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
  template <>                                                                 \
  Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                           int m, int n, int k, const Scalar* dev_a, int lda, \
                           const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                           int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,     \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
                           const Scalar* dev_tau, int* dev_lapack_info) {   \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
                           int n, Scalar* dev_A, int lda,                      \
                           typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
                           int* dev_lapack_info) {                             \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(
    BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
    OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
    signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
    Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return Status::OK();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status CudaSolver::Gesvd<Scalar>(                                      \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
                                       CudaSolver* cuda_solver,
                                       OpKernelContext* context,
                                       cusolverDnHandle_t cusolver_dn_handle,
                                       cusolverEigMode_t jobz, int m, int n,
                                       Scalar* A, int lda, Scalar* S, Scalar* U,
                                       int ldu, Scalar* V, int ldv,
                                       int* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  /* Default parameters for gesvdj and gesvdjBatched. */
  gesvdjInfo_t svdj_info;
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
      lwork, dev_lapack_info, svdj_info, batch_size));
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
  return Status::OK();
}

#define GESVDJBATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GesvdjBatched<Scalar>(                                    \
      cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda,            \
      Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv,           \
      int* dev_lapack_info, int batch_size) {                                  \
    return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix),        \
                             DN_SOLVER_FN(gesvdjBatched, type_prefix), this,   \
                             context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
                             lda, dev_S, dev_U, ldu, dev_V, ldv,               \
                             dev_lapack_info, batch_size);                     \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
//=============================================================================
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrfBatched(                                             \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
    int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
      ldb, host_lapack_info, batch_size));
  return Status::OK();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrsBatched(                                             \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,   \
      int batch_size) {                                                        \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, host_lapack_info, batch_size);                \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status CudaSolver::GetriBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
      dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status CudaSolver::MatInvBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasSideMode_t side, cublasFillMode_t uplo,
                              cublasOperation_t trans, cublasDiagType_t diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(B), ldb));
  return Status::OK();
}

#define TRSM_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status CudaSolver::Trsm<Scalar>(                                           \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
      cublasDiagType_t diag, int m, int n,                                   \
      const Scalar* alpha, /* host or device pointer */                      \
      const Scalar* A, int lda, Scalar* B, int ldb) {                        \
    return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
                    uplo, trans, diag, m, n, alpha, A, lda, B, ldb);         \
  }

TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasFillMode_t uplo, cublasOperation_t trans,
                              cublasDiagType_t diag, int n, const Scalar* A,
                              int lda, Scalar* x, int incx) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(x), incx));
  return Status::OK();
}

#define TRSV_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status CudaSolver::Trsv<Scalar>(                                           \
      cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
      int n, const Scalar* A, int lda, Scalar* x, int incx) {                \
    return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
                    trans, diag, n, A, lda, x, incx);                        \
  }

TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
    cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
    const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
    Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, side, uplo, trans, diag, m, n,
             reinterpret_cast<const CudaScalar*>(alpha),
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
             ldb, batch_size));
  return Status::OK();
}

#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                 \
  Status CudaSolver::TrsmBatched(                                             \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
      cublasDiagType_t diag, int m, int n, const Scalar* alpha,               \
      const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[],        \
      int ldb, int batch_size) {                                              \
    return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this,    \
                           context_, cublas_handle_, side, uplo, trans, diag, \
                           m, n, alpha, dev_Aarray, lda, dev_Barray, ldb,     \
                           batch_size);                                       \
  }

TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA