/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"

#include <algorithm>
#include <complex>
#include <utility>

#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

namespace {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct GpuComplexT {
  typedef T type;
};
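// The primary template leaves real element types (float, double) unchanged;
// the specializations below map std::complex<float> and std::complex<double>
// to the GPU library's native complex types (cuComplex/cuDoubleComplex on
// CUDA; the hip or rocblas equivalents on ROCm).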

// For ROCm, use hipsolver if the ROCm version >= 4.5 and
// rocblas/rocsolver if the ROCm version < 4.5.

#if !TENSORFLOW_USE_ROCM

#define GPU_SOLVER_CONTEXT_PREFIX cusolverDn
#define GPU_SOLVER_PREFIX cusolverDn

using gpuStream_t = cudaStream_t;

template <>
struct GpuComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};

template <>
struct GpuComplexT<std::complex<float>*> {
  typedef cuComplex* type;
};
template <>
struct GpuComplexT<std::complex<double>*> {
  typedef cuDoubleComplex* type;
};

#else

using gpuStream_t = hipStream_t;

#if TF_ROCM_VERSION >= 40500
#define GPU_SOLVER_CONTEXT_PREFIX tensorflow::wrap::hipsolver
#define GPU_SOLVER_PREFIX tensorflow::wrap::hipsolver

template <>
struct GpuComplexT<std::complex<float>> {
  typedef hipFloatComplex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef hipDoubleComplex type;
};

template <>
struct GpuComplexT<std::complex<float>*> {
  typedef hipFloatComplex* type;
};
template <>
struct GpuComplexT<std::complex<double>*> {
  typedef hipDoubleComplex* type;
};
#else
#define GPU_SOLVER_CONTEXT_PREFIX tensorflow::wrap::rocblas_
#define GPU_SOLVER_PREFIX tensorflow::wrap::rocsolver_

template <>
struct GpuComplexT<std::complex<float>> {
  typedef rocblas_float_complex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef rocblas_double_complex type;
};

template <>
struct GpuComplexT<std::complex<float>*> {
  typedef rocblas_float_complex* type;
};
template <>
struct GpuComplexT<std::complex<double>*> {
  typedef rocblas_double_complex* type;
};
#endif  // TF_ROCM_VERSION >= 40500

#endif  // !TENSORFLOW_USE_ROCM

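// Returns the raw device pointer of `p`, cast via GpuComplexT to the element
// type the underlying GPU library expects.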
template <typename T>
inline typename GpuComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
  return static_cast<typename GpuComplexT<T>::type*>(p.opaque());
}

#if !TENSORFLOW_USE_ROCM
cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return CUBLAS_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return CUBLAS_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a cuSolver status to a Status.
Status ConvertStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return OkStatus();
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return Unimplemented("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return InvalidArgument("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return FailedPrecondition("cuSolver invalid license error");
    default:
      return Unknown("Unknown cuSolver error");
  }
}
#else
#if TF_ROCM_VERSION >= 40500
hipsolverFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return HIPSOLVER_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return HIPSOLVER_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a hipsolver status to a Status.
Status ConvertStatus(hipsolverStatus_t status) {
  switch (status) {
    case HIPSOLVER_STATUS_SUCCESS:
      return OkStatus();
    case HIPSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("hipsolver has not been initialized");
    case HIPSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("hipsolver allocation failed");
    case HIPSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("hipsolver invalid value error");
    case HIPSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("hipsolver mapping error");
    case HIPSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("hipsolver execution failed");
    case HIPSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("hipsolver internal error");
    case HIPSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("hipsolver not supported error");
    case HIPSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("hipsolver architecture mismatch error");
    case HIPSOLVER_STATUS_HANDLE_IS_NULLPTR:
      return InvalidArgument("hipsolver handle is nullptr error");
    case HIPSOLVER_STATUS_INVALID_ENUM:
      return InvalidArgument("hipsolver invalid enum error");
    case HIPSOLVER_STATUS_UNKNOWN:
      return Unknown("hipsolver status unknown");
    default:
      return Unknown("Unknown hipsolver error");
  }
}
#else
rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return rocblas_fill_upper;
    case se::blas::UpperLower::kLower:
      return rocblas_fill_lower;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a rocblas status to a Status.
Status ConvertStatus(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:
      return OkStatus();
    case rocblas_status_invalid_handle:
      return FailedPrecondition("handle not initialized, invalid or null");
    case rocblas_status_not_implemented:
      return Internal("function is not implemented");
    case rocblas_status_invalid_pointer:
      return InvalidArgument("invalid pointer argument");
    case rocblas_status_invalid_size:
      return InvalidArgument("invalid size argument");
    case rocblas_status_memory_error:
      return Internal("failed internal memory allocation, copy or dealloc");
    case rocblas_status_internal_error:
      return Internal("other internal library failure");
    case rocblas_status_perf_degraded:
      return Internal("performance degraded due to low device memory");
    case rocblas_status_size_query_mismatch:
      return Unknown("unmatched start/stop size query");
    case rocblas_status_size_increased:
      return Unknown("queried device memory size increased");
    case rocblas_status_size_unchanged:
      return Unknown("queried device memory size unchanged");
    case rocblas_status_invalid_value:
      return InvalidArgument("passed argument not valid");
    case rocblas_status_continue:
      return Unknown("nothing preventing function to proceed");
    default:
      return Unknown("Unknown rocsolver error");
  }
}
#endif  // TF_ROCM_VERSION >= 40500
#endif  // !TENSORFLOW_USE_ROCM

#define GPU_SOLVER_CAT_NX(A, B) A##B
#define GPU_SOLVER_CAT(A, B) GPU_SOLVER_CAT_NX(A, B)
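// The two-level concatenation forces GPU_SOLVER_CONTEXT_PREFIX and
// GPU_SOLVER_PREFIX to be macro-expanded before token pasting, so that e.g.
// GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, Create) yields cusolverDnCreate
// on CUDA builds rather than the literal GPU_SOLVER_CONTEXT_PREFIXCreate.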

#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
#define GpuSolverCreate GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, Create)
#define GpuSolverSetStream GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, SetStream)
#define GpuSolverDestroy GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, Destroy)
#else  // TENSORFLOW_USE_ROCSOLVER
#define GpuSolverCreate GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, create_handle)
#define GpuSolverSetStream GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, set_stream)
#define GpuSolverDestroy \
  GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, destroy_handle)
#endif
#define GpuSolverSpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Spotrf_bufferSize)
#define GpuSolverDpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Dpotrf_bufferSize)
#define GpuSolverCpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Cpotrf_bufferSize)
#define GpuSolverZpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Zpotrf_bufferSize)
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
#define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Spotrf)
#define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Dpotrf)
#define GpuSolverCpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Cpotrf)
#define GpuSolverZpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Zpotrf)
#define GpuSolverSpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, SpotrfBatched)
#define GpuSolverDpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, DpotrfBatched)
#define GpuSolverCpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, CpotrfBatched)
#define GpuSolverZpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, ZpotrfBatched)
#else  // TENSORFLOW_USE_ROCSOLVER
#define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, spotrf)
#define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, dpotrf)
#define GpuSolverCpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, cpotrf)
#define GpuSolverZpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, zpotrf)
#define GpuSolverSpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, spotrf_batched)
#define GpuSolverDpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, dpotrf_batched)
#define GpuSolverCpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, cpotrf_batched)
#define GpuSolverZpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, zpotrf_batched)
#endif

}  // namespace

StatusOr<GpuSolverContext> GpuSolverContext::Create(se::Stream* stream) {
  gpusolverHandle_t handle;
  TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCreate(&handle)));
  GpuSolverContext context(stream, handle);

  if (stream) {
    // StreamExecutor really should just expose the Cuda stream to clients...
    const gpuStream_t* gpu_stream =
        CHECK_NOTNULL(reinterpret_cast<const gpuStream_t*>(
            stream->implementation()->GpuStreamMemberHack()));
    TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverSetStream(handle, *gpu_stream)));
  }

  return std::move(context);
}

GpuSolverContext::GpuSolverContext(se::Stream* stream, gpusolverHandle_t handle)
    : stream_(stream), handle_(handle) {}

GpuSolverContext::GpuSolverContext(GpuSolverContext&& other) {
  handle_ = other.handle_;
  stream_ = other.stream_;
  other.handle_ = nullptr;
  other.stream_ = nullptr;
}

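// Move assignment swaps state with `other`, so the moved-from context's
// destructor releases whichever handle this context previously owned.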
GpuSolverContext& GpuSolverContext::operator=(GpuSolverContext&& other) {
  std::swap(handle_, other.handle_);
  std::swap(stream_, other.stream_);
  return *this;
}

GpuSolverContext::~GpuSolverContext() {
  if (handle_) {
    Status status = ConvertStatus(GpuSolverDestroy(handle_));
    if (!status.ok()) {
      LOG(ERROR) << "GpuSolverDestroy failed: " << status;
    }
  }
}

// Note: NVidia have promised that it is safe to pass 'nullptr' as the argument
// buffers to cuSolver buffer size methods and this will be a documented
// behavior in a future cuSolver release.
StatusOr<int64_t> GpuSolverContext::PotrfBufferSize(PrimitiveType type,
                                                    se::blas::UpperLower uplo,
                                                    int n, int lda,
                                                    int batch_size) {
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
  int size = -1;
  switch (type) {
    case F32: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverSpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case F64: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverDpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C64: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverZpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    default:
      return InvalidArgument("Invalid type for cholesky decomposition: %s",
                             PrimitiveType_Name(type));
  }
  // CUDA's potrfBatched needs space for the `as` array, which contains
  // batch_size pointers.  Divide by sizeof(type) because this function returns
  // not bytes but a number of elements of `type`.
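  // For example, with batch_size = 4 and F32 (4-byte elements) on a 64-bit
  // target (8-byte pointers), this reserves ceil(4 * 8 / 4) = 8 elements.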
  int64_t potrf_batched_scratch = CeilOfRatio<int64_t>(
      batch_size * sizeof(void*), primitive_util::ByteWidth(type));

  return std::max<int64_t>(size, potrf_batched_scratch);
#else  // not supported in rocsolver
  return 0;
#endif
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<float> a, int lda,
                               se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverSpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<float>(workspace)),
      se::DeviceMemory<float>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<double> a, int lda,
                               se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverDpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<double>(workspace)),
      se::DeviceMemory<double>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<std::complex<float>> a, int lda,
                               se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverCpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<std::complex<float>>(workspace)),
      se::DeviceMemory<std::complex<float>>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<std::complex<double>> a,
                               int lda, se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverZpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<std::complex<double>>(workspace)),
      se::DeviceMemory<std::complex<double>>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n,
                                      se::DeviceMemory<float*> as, int lda,
                                      se::DeviceMemory<int> lapack_info,
                                      int batch_size) {
  return ConvertStatus(GpuSolverSpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
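      // hipsolver's batched routines take explicit workspace arguments; no
      // scratch space is supplied here.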
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n,
                                      se::DeviceMemory<double*> as, int lda,
                                      se::DeviceMemory<int> lapack_info,
                                      int batch_size) {
  return ConvertStatus(GpuSolverDpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n,
                                      se::DeviceMemory<std::complex<float>*> as,
                                      int lda,
                                      se::DeviceMemory<int> lapack_info,
                                      int batch_size) {
  return ConvertStatus(GpuSolverCpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

Status GpuSolverContext::PotrfBatched(
    se::blas::UpperLower uplo, int n,
    se::DeviceMemory<std::complex<double>*> as, int lda,
    se::DeviceMemory<int> lapack_info, int batch_size) {
  return ConvertStatus(GpuSolverZpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

}  // namespace gpu
}  // namespace xla