/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"

#include <algorithm>
#include <complex>
#include <utility>

#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

namespace {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct GpuComplexT {
  typedef T type;
};
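
// For example, on the CUDA path GpuComplexT<std::complex<float>>::type is
// cuComplex and GpuComplexT<std::complex<float>*>::type is cuComplex*; the
// specializations below provide the analogous mappings for each backend.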

// For ROCm, use hipsolver if the ROCm version >= 4.5 and
// rocblas/rocsolver if the ROCm version < 4.5.

#if !TENSORFLOW_USE_ROCM

#define GPU_SOLVER_CONTEXT_PREFIX cusolverDn
#define GPU_SOLVER_PREFIX cusolverDn

using gpuStream_t = cudaStream_t;

template <>
struct GpuComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};

template <>
struct GpuComplexT<std::complex<float>*> {
  typedef cuComplex* type;
};
template <>
struct GpuComplexT<std::complex<double>*> {
  typedef cuDoubleComplex* type;
};

#else

using gpuStream_t = hipStream_t;

#if TF_ROCM_VERSION >= 40500
#define GPU_SOLVER_CONTEXT_PREFIX tensorflow::wrap::hipsolver
#define GPU_SOLVER_PREFIX tensorflow::wrap::hipsolver

template <>
struct GpuComplexT<std::complex<float>> {
  typedef hipFloatComplex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef hipDoubleComplex type;
};

template <>
struct GpuComplexT<std::complex<float>*> {
  typedef hipFloatComplex* type;
};
template <>
struct GpuComplexT<std::complex<double>*> {
  typedef hipDoubleComplex* type;
};
#else
#define GPU_SOLVER_CONTEXT_PREFIX tensorflow::wrap::rocblas_
#define GPU_SOLVER_PREFIX tensorflow::wrap::rocsolver_

template <>
struct GpuComplexT<std::complex<float>> {
  typedef rocblas_float_complex type;
};
template <>
struct GpuComplexT<std::complex<double>> {
  typedef rocblas_double_complex type;
};

template <>
struct GpuComplexT<std::complex<float>*> {
  typedef rocblas_float_complex* type;
};
template <>
struct GpuComplexT<std::complex<double>*> {
  typedef rocblas_double_complex* type;
};
#endif  // TF_ROCM_VERSION >= 40500

#endif  // !TENSORFLOW_USE_ROCM

template <typename T>
inline typename GpuComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
  return static_cast<typename GpuComplexT<T>::type*>(p.opaque());
}

#if !TENSORFLOW_USE_ROCM
cublasFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return CUBLAS_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return CUBLAS_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a cuSolver status to a Status.
Status ConvertStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return OkStatus();
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return Unimplemented("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return InvalidArgument("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return FailedPrecondition("cuSolver invalid license error");
    default:
      return Unknown("Unknown cuSolver error");
  }
}
#else
#if TF_ROCM_VERSION >= 40500
hipsolverFillMode_t GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return HIPSOLVER_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return HIPSOLVER_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a hipsolver status to a Status.
Status ConvertStatus(hipsolverStatus_t status) {
  switch (status) {
    case HIPSOLVER_STATUS_SUCCESS:
      return OkStatus();
    case HIPSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("hipsolver has not been initialized");
    case HIPSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("hipsolver allocation failed");
    case HIPSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("hipsolver invalid value error");
    case HIPSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("hipsolver mapping error");
    case HIPSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("hipsolver execution failed");
    case HIPSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("hipsolver internal error");
    case HIPSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("hipsolver not supported error");
    case HIPSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("hipsolver architecture mismatch error");
    case HIPSOLVER_STATUS_HANDLE_IS_NULLPTR:
      return InvalidArgument("hipsolver handle is nullptr error");
    case HIPSOLVER_STATUS_INVALID_ENUM:
      return InvalidArgument("hipsolver invalid enum error");
    case HIPSOLVER_STATUS_UNKNOWN:
      return Unknown("hipsolver status unknown");
    default:
      return Unknown("Unknown hipsolver error");
  }
}
#else
rocblas_fill GpuBlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return rocblas_fill_upper;
    case se::blas::UpperLower::kLower:
      return rocblas_fill_lower;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a rocblas/rocsolver status to a Status.
Status ConvertStatus(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:
      return OkStatus();
    case rocblas_status_invalid_handle:
      return FailedPrecondition("handle not initialized, invalid or null");
    case rocblas_status_not_implemented:
      return Internal("function is not implemented");
    case rocblas_status_invalid_pointer:
      return InvalidArgument("invalid pointer argument");
    case rocblas_status_invalid_size:
      return InvalidArgument("invalid size argument");
    case rocblas_status_memory_error:
      return Internal("failed internal memory allocation, copy or dealloc");
    case rocblas_status_internal_error:
      return Internal("other internal library failure");
    case rocblas_status_perf_degraded:
      return Internal("performance degraded due to low device memory");
    case rocblas_status_size_query_mismatch:
      return Unknown("unmatched start/stop size query");
    case rocblas_status_size_increased:
      return Unknown("queried device memory size increased");
    case rocblas_status_size_unchanged:
      return Unknown("queried device memory size unchanged");
    case rocblas_status_invalid_value:
      return InvalidArgument("passed argument not valid");
    case rocblas_status_continue:
      return Unknown("nothing preventing function to proceed");
    default:
      return Unknown("Unknown rocsolver error");
  }
}
#endif  // TF_ROCM_VERSION >= 40500
#endif  // TENSORFLOW_USE_ROCM

#define GPU_SOLVER_CAT_NX(A, B) A##B
#define GPU_SOLVER_CAT(A, B) GPU_SOLVER_CAT_NX(A, B)
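// The extra _NX level forces GPU_SOLVER_CONTEXT_PREFIX / GPU_SOLVER_PREFIX to
// be macro-expanded before token pasting, so that
// GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, Create) yields e.g. cusolverDnCreate
// rather than the literal token GPU_SOLVER_CONTEXT_PREFIXCreate.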

#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
#define GpuSolverCreate GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, Create)
#define GpuSolverSetStream GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, SetStream)
#define GpuSolverDestroy GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, Destroy)
#else  // TENSORFLOW_USE_ROCSOLVER
#define GpuSolverCreate GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, create_handle)
#define GpuSolverSetStream GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, set_stream)
#define GpuSolverDestroy \
  GPU_SOLVER_CAT(GPU_SOLVER_CONTEXT_PREFIX, destroy_handle)
#endif
#define GpuSolverSpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Spotrf_bufferSize)
#define GpuSolverDpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Dpotrf_bufferSize)
#define GpuSolverCpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Cpotrf_bufferSize)
#define GpuSolverZpotrf_bufferSize \
  GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Zpotrf_bufferSize)
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
#define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Spotrf)
#define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Dpotrf)
#define GpuSolverCpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Cpotrf)
#define GpuSolverZpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, Zpotrf)
#define GpuSolverSpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, SpotrfBatched)
#define GpuSolverDpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, DpotrfBatched)
#define GpuSolverCpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, CpotrfBatched)
#define GpuSolverZpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, ZpotrfBatched)
#else  // TENSORFLOW_USE_ROCSOLVER
#define GpuSolverSpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, spotrf)
#define GpuSolverDpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, dpotrf)
#define GpuSolverCpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, cpotrf)
#define GpuSolverZpotrf GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, zpotrf)
#define GpuSolverSpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, spotrf_batched)
#define GpuSolverDpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, dpotrf_batched)
#define GpuSolverCpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, cpotrf_batched)
#define GpuSolverZpotrfBatched GPU_SOLVER_CAT(GPU_SOLVER_PREFIX, zpotrf_batched)
#endif
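
// As a concrete example, GpuSolverSpotrf expands to cusolverDnSpotrf on the
// CUDA path, to tensorflow::wrap::hipsolverSpotrf on ROCm >= 4.5, and to
// tensorflow::wrap::rocsolver_spotrf on older ROCm versions.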

}  // namespace

StatusOr<GpuSolverContext> GpuSolverContext::Create(se::Stream* stream) {
  gpusolverHandle_t handle;
  TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCreate(&handle)));
  GpuSolverContext context(stream, handle);

  if (stream) {
    // StreamExecutor really should just expose the Cuda stream to clients...
    const gpuStream_t* gpu_stream =
        CHECK_NOTNULL(reinterpret_cast<const gpuStream_t*>(
            stream->implementation()->GpuStreamMemberHack()));
    TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverSetStream(handle, *gpu_stream)));
  }

  return std::move(context);
}
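
// A minimal usage sketch (illustrative only; assumes a live se::Stream* and
// device buffers `a`, `lapack_info`, and `workspace` already sized via
// PotrfBufferSize):
//   TF_ASSIGN_OR_RETURN(GpuSolverContext context,
//                       GpuSolverContext::Create(stream));
//   TF_RETURN_IF_ERROR(context.Potrf(se::blas::UpperLower::kLower, n, a, lda,
//                                    lapack_info, workspace));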

GpuSolverContext::GpuSolverContext(se::Stream* stream, gpusolverHandle_t handle)
    : stream_(stream), handle_(handle) {}

GpuSolverContext::GpuSolverContext(GpuSolverContext&& other) {
  handle_ = other.handle_;
  stream_ = other.stream_;
  other.handle_ = nullptr;
  other.stream_ = nullptr;
}

GpuSolverContext& GpuSolverContext::operator=(GpuSolverContext&& other) {
  std::swap(handle_, other.handle_);
  std::swap(stream_, other.stream_);
  return *this;
}

GpuSolverContext::~GpuSolverContext() {
  if (handle_) {
    Status status = ConvertStatus(GpuSolverDestroy(handle_));
    if (!status.ok()) {
      LOG(ERROR) << "GpuSolverDestroy failed: " << status;
    }
  }
}

// Note: NVIDIA has promised that it is safe to pass 'nullptr' as the argument
// buffers to cuSolver buffer size methods and that this will be documented
// behavior in a future cuSolver release.
StatusOr<int64_t> GpuSolverContext::PotrfBufferSize(PrimitiveType type,
                                                    se::blas::UpperLower uplo,
                                                    int n, int lda,
                                                    int batch_size) {
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
  int size = -1;
  switch (type) {
    case F32: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverSpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case F64: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverDpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C64: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverCpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(ConvertStatus(GpuSolverZpotrf_bufferSize(
          handle(), GpuBlasUpperLower(uplo), n, /*A=*/nullptr, lda, &size)));
      break;
    }
    default:
      return InvalidArgument("Invalid type for Cholesky decomposition: %s",
                             PrimitiveType_Name(type));
  }
  // CUDA's potrfBatched needs space for the `as` array, which contains
  // batch_size pointers. Divide by sizeof(type) because this function returns
  // not bytes but a number of elements of `type`.
  int64_t potrf_batched_scratch = CeilOfRatio<int64_t>(
      batch_size * sizeof(void*), primitive_util::ByteWidth(type));
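  // For example, with F32 (4-byte elements) and batch_size = 8 on a 64-bit
  // platform, this reserves CeilOfRatio(8 * 8, 4) = 16 elements for the
  // pointer array.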

  return std::max<int64_t>(size, potrf_batched_scratch);
#else  // not supported in rocsolver
  return 0;
#endif
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<float> a, int lda,
                               se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverSpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<float>(workspace)),
      se::DeviceMemory<float>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<double> a, int lda,
                               se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverDpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<double>(workspace)),
      se::DeviceMemory<double>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<std::complex<float>> a, int lda,
                               se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverCpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<std::complex<float>>(workspace)),
      se::DeviceMemory<std::complex<float>>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::Potrf(se::blas::UpperLower uplo, int n,
                               se::DeviceMemory<std::complex<double>> a,
                               int lda, se::DeviceMemory<int> lapack_info,
                               se::DeviceMemoryBase workspace) {
  return ConvertStatus(GpuSolverZpotrf(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(a), lda,
#if TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER
      ToDevicePointer(se::DeviceMemory<std::complex<double>>(workspace)),
      se::DeviceMemory<std::complex<double>>(workspace).ElementCount(),
#endif
      ToDevicePointer(lapack_info)));
}

Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n,
                                      se::DeviceMemory<float*> as, int lda,
                                      se::DeviceMemory<int> lapack_info,
                                      int batch_size) {
  return ConvertStatus(GpuSolverSpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n,
                                      se::DeviceMemory<double*> as, int lda,
                                      se::DeviceMemory<int> lapack_info,
                                      int batch_size) {
  return ConvertStatus(GpuSolverDpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

Status GpuSolverContext::PotrfBatched(se::blas::UpperLower uplo, int n,
                                      se::DeviceMemory<std::complex<float>*> as,
                                      int lda,
                                      se::DeviceMemory<int> lapack_info,
                                      int batch_size) {
  return ConvertStatus(GpuSolverCpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}

Status GpuSolverContext::PotrfBatched(
    se::blas::UpperLower uplo, int n,
    se::DeviceMemory<std::complex<double>*> as, int lda,
    se::DeviceMemory<int> lapack_info, int batch_size) {
  return ConvertStatus(GpuSolverZpotrfBatched(
      handle(), GpuBlasUpperLower(uplo), n, ToDevicePointer(as), lda,
#if TENSORFLOW_USE_HIPSOLVER
      nullptr, 0,
#endif
      ToDevicePointer(lapack_info), batch_size));
}
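
// Note: each PotrfBatched overload expects `as` to hold batch_size device
// pointers, one per n x n input matrix, with `lapack_info` sized for one
// status entry per batch member (matching the cuSOLVER potrfBatched
// contract); the hipsolver path here additionally passes a null workspace.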

}  // namespace gpu
}  // namespace xla