/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_

// This header declares the class GpuSparse, which contains wrappers of the
// cuSparse library (hipSPARSE on ROCm) for use in TensorFlow kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cusparse.h"

using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;

#elif TENSORFLOW_USE_ROCM

#include "rocm/include/hipsparse/hipsparse.h"

using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;

#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/public/version.h"

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
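
// For illustration only: a consumer typically defines a two-argument macro and
// passes it to TF_CALL_LAPACK_TYPES, which expands it once per (type, prefix
// letter) pair. The macro name below is hypothetical, shown only as a sketch:
//
//   #define DECLARE_GTSV(Scalar, prefix)                                    \
//     template Status GpuSparse::Gtsv<Scalar>(int, int, const Scalar*,      \
//                                             const Scalar*, const Scalar*, \
//                                             Scalar*, int) const;
//   TF_CALL_LAPACK_TYPES(DECLARE_GTSV);
//   #undef DECLARE_GTSV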

namespace tensorflow {

inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
  switch (status) {
#define STRINGIZE(q) #q
#define RETURN_IF_STATUS(err) \
  case err:                   \
    return STRINGIZE(err);

#if GOOGLE_CUDA

    RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

    default:
      return strings::StrCat("Unknown CUSPARSE error: ",
                             static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM

    RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)

    default:
      return strings::StrCat("Unknown hipSPARSE error: ",
                             static_cast<int>(status));
#endif

#undef RETURN_IF_STATUS
#undef STRINGIZE
  }
}

#if GOOGLE_CUDA

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): cuSparse call failed with status ",      \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#elif TENSORFLOW_USE_ROCM

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) {            \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): hipSPARSE call failed with status ",     \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#endif

inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
                                                               bool conjugate,
                                                               Status* status) {
#if GOOGLE_CUDA
  if (transpose) {
    return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : CUSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return CUSPARSE_OPERATION_NON_TRANSPOSE;
  }
#elif TENSORFLOW_USE_ROCM
  if (transpose) {
    return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : HIPSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return HIPSPARSE_OPERATION_NON_TRANSPOSE;
  }
#endif
}
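
// Example (a minimal sketch; `ctx`, `transpose_a` and `conjugate_a` are
// caller-supplied placeholders, not part of this header):
//
//   Status op_status;
//   const gpusparseOperation_t transA =
//       TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
//                                          &op_status);
//   OP_REQUIRES_OK(ctx, op_status);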

// The GpuSparse class provides a simplified templated API for cuSparse
// (http://docs.nvidia.com/cuda/cusparse/index.html).
// An object of this class wraps static cuSparse instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
// object.

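// A typical call sequence inside a GPU kernel's Compute() looks roughly like
// the sketch below; `ctx`, `csr_row_ptr`, `coo_row_ind`, `m` and `nnz` are
// placeholders supplied by the caller, not part of this header:
//
//   GpuSparse cuda_sparse(ctx);
//   OP_REQUIRES_OK(ctx, cuda_sparse.Initialize());
//   OP_REQUIRES_OK(ctx,
//                  cuda_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));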
class GpuSparse {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSparse(OpKernelContext* context);
  virtual ~GpuSparse() {}

  // This initializes the GpuSparse class if it hasn't
  // been initialized yet. All following public methods require the
  // class to have been initialized. Can be run multiple times; all
  // subsequent calls after the first have no effect.
  Status Initialize();  // Move to constructor?

  // ====================================================================
  // Wrappers for cuSparse start here.
  //

  // Solves a tridiagonal system of equations.
  // Note: Cuda Toolkit 9.0+ has the better-performing gtsv2 routine. gtsv
  // will be removed in Cuda Toolkit 11.0.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv
  // Returns Status::OK() if the kernel was launched successfully.
  template <typename Scalar>
  Status Gtsv(int m, int n, const Scalar *dl, const Scalar *d, const Scalar *du,
              Scalar *B, int ldb) const;

  // Solves a tridiagonal system of equations without pivoting.
  // Note: Cuda Toolkit 9.0+ has the better-performing gtsv2_nopivot routine.
  // gtsv_nopivot will be removed in Cuda Toolkit 11.0.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv_nopivot
  // Returns Status::OK() if the kernel was launched successfully.
  template <typename Scalar>
  Status GtsvNoPivot(int m, int n, const Scalar *dl, const Scalar *d,
                     const Scalar *du, Scalar *B, int ldb) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // Note: Cuda Toolkit 9.0+ has the better-performing gtsv2StridedBatch
  // routine. gtsvStridedBatch will be removed in Cuda Toolkit 11.0.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsvstridedbatch
  // Returns Status::OK() if the kernel was launched successfully.
  template <typename Scalar>
  Status GtsvStridedBatch(int m, const Scalar *dl, const Scalar *d,
                          const Scalar *du, Scalar *x, int batchCount,
                          int batchStride) const;

  // Solves a tridiagonal system of equations.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
  template <typename Scalar>
  Status Gtsv2(int m, int n, const Scalar *dl, const Scalar *d,
               const Scalar *du, Scalar *B, int ldb, void *pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
  template <typename Scalar>
  Status Gtsv2BufferSizeExt(int m, int n, const Scalar *dl, const Scalar *d,
                            const Scalar *du, const Scalar *B, int ldb,
                            size_t *bufferSizeInBytes) const;
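
  // The gtsv2 wrappers follow the usual cuSparse two-step pattern: query the
  // workspace size, allocate that many bytes on the device, then run the
  // solver with the scratch buffer. A rough sketch (allocation and error
  // handling elided; `solver` is an initialized GpuSparse and the pointers
  // are caller-supplied device memory):
  //
  //   size_t buffer_size = 0;
  //   TF_RETURN_IF_ERROR(solver.Gtsv2BufferSizeExt(m, n, dl, d, du, B, ldb,
  //                                                &buffer_size));
  //   // ... allocate a device scratch buffer `buffer` of buffer_size bytes ...
  //   TF_RETURN_IF_ERROR(solver.Gtsv2(m, n, dl, d, du, B, ldb, buffer));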

  // Solves a tridiagonal system of equations without partial pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
  template <typename Scalar>
  Status Gtsv2NoPivot(int m, int n, const Scalar *dl, const Scalar *d,
                      const Scalar *du, Scalar *B, int ldb,
                      void *pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2NoPivot.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
  template <typename Scalar>
  Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar *dl,
                                   const Scalar *d, const Scalar *du,
                                   const Scalar *B, int ldb,
                                   size_t *bufferSizeInBytes) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
  template <typename Scalar>
  Status Gtsv2StridedBatch(int m, const Scalar *dl, const Scalar *d,
                           const Scalar *du, Scalar *x, int batchCount,
                           int batchStride, void *pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2StridedBatch.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
  template <typename Scalar>
  Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar *dl,
                                        const Scalar *d, const Scalar *du,
                                        const Scalar *x, int batchCount,
                                        int batchStride,
                                        size_t *bufferSizeInBytes) const;

  // Uncompresses the indices of rows or columns. It can be interpreted as a
  // conversion from CSR to COO sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
  Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const;

  // Compresses the indices of rows or columns. It can be interpreted as a
  // conversion from COO to CSR sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
  Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;

  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, and B and C are dense tall
  // matrices. This routine allows transposition of matrix B, which
  // may improve performance. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
  //
  // **NOTE** Matrices B and C are expected to be in column-major
  // order; to make them consistent with TensorFlow they
  // must be transposed (or the matmul op's pre/post-processing must take this
  // into account).
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
               int n, int k, int nnz, const Scalar* alpha_host,
               const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
               int ldc) const;
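
  // A rough sketch of a Csrmm call with host-side alpha/beta scalars. The
  // other arguments are caller-supplied placeholders; following the cuSparse
  // csrmm2 convention, op(A) is m x k, op(B) is k x n and C is m x n, with B
  // and C stored in column-major order:
  //
  //   const float alpha = 1.0f, beta = 0.0f;
  //   TF_RETURN_IF_ERROR(solver.Csrmm(transA, transB, m, n, k, nnz, &alpha,
  //                                   descrA.descr(), csr_val, csr_row_ptr,
  //                                   csr_col_ind, B, ldb, &beta, C, ldc));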

  // Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
  // where A is a sparse matrix in CSR format, and x and y are dense vectors.
  // See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
  //
  // **NOTE** This is an in-place operation for data in y.
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
               const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
               const int* csrSortedColIndA, const Scalar* x,
               const Scalar* beta_host, Scalar* y) const;

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on the device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
                    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part two: perform the sparse-sparse
  // addition. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  template <typename Scalar>
  Status Csrgeam(int m, int n, const Scalar* alpha,
                 const gpusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const Scalar* beta,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);
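
  // The two csrgeam wrappers are used together: CsrgeamNnz computes the
  // number of nonzeros of the output C, the caller allocates csrSortedValC
  // and csrSortedColIndC with that many entries, and Csrgeam then fills them
  // in. A rough sketch (allocation and error handling elided; `solver`, the
  // descriptors and the CSR arrays are caller-supplied placeholders):
  //
  //   int nnzC = 0;
  //   TF_RETURN_IF_ERROR(solver.CsrgeamNnz(
  //       m, n, descrA.descr(), nnzA, row_ptr_a, col_ind_a, descrB.descr(),
  //       nnzB, row_ptr_b, col_ind_b, descrC.descr(), row_ptr_c, &nnzC));
  //   // ... allocate val_c and col_ind_c with nnzC entries each ...
  //   TF_RETURN_IF_ERROR(solver.Csrgeam(
  //       m, n, &alpha, descrA.descr(), nnzA, val_a, row_ptr_a, col_ind_a,
  //       &beta, descrB.descr(), nnzB, val_b, row_ptr_b, col_ind_b,
  //       descrC.descr(), val_c, row_ptr_c, col_ind_c));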

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on the device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
  Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
                    int m, int k, int n, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part two: perform the sparse-sparse
  // multiplication. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
  template <typename Scalar>
  Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
                 int m, int k, int n, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);

  // In-place reordering of an unsorted CSR matrix to sorted CSR.
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
  template <typename Scalar>
  Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
                  Scalar* csrVal, const int* csrRowPtr, int* csrColInd);

  // Converts from CSR to CSC format (equivalently, computes the transpose).
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
  template <typename Scalar>
  Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
                 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
                 int* cscRowInd, int* cscColPtr,
                 const gpusparseAction_t copyValues);

 private:
  bool initialized_;
  OpKernelContext *context_;  // not owned.
  gpuStream_t gpu_stream_;
  gpusparseHandle_t* gpusparse_handle_;  // not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
};

// A wrapper class to ensure that a CUDA sparse matrix descriptor is
// initialized only once. For more details on the descriptor
// (gpusparseMatDescr_t), see:
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
class GpuSparseMatrixDescriptor {
 public:
  explicit GpuSparseMatrixDescriptor() : initialized_(false) {}

  GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
      : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
    rhs.initialized_ = false;
  }

  GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    descr_ = std::move(rhs.descr_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseMatrixDescriptor() { Release(); }

  // Initializes the underlying descriptor. Will fail if called more than
  // once.
  Status Initialize() {
    DCHECK(!initialized_);
#if GOOGLE_CUDA
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
#endif
    initialized_ = true;
    return Status::OK();
  }

  gpusparseMatDescr_t& descr() {
    DCHECK(initialized_);
    return descr_;
  }

  const gpusparseMatDescr_t& descr() const {
    DCHECK(initialized_);
    return descr_;
  }

 private:
  void Release() {
    if (initialized_) {
#if GOOGLE_CUDA
      cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
      hipsparseDestroyMatDescr(descr_);
#endif
      initialized_ = false;
    }
  }

  bool initialized_;
  gpusparseMatDescr_t descr_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
};
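
// Example use together with GpuSparse (a minimal sketch; `solver`, the
// scalars and the CSR arrays are supplied by the surrounding kernel and are
// not part of this header):
//
//   GpuSparseMatrixDescriptor descrA;
//   TF_RETURN_IF_ERROR(descrA.Initialize());
//   TF_RETURN_IF_ERROR(solver.Csrmv(transA, m, n, nnz, &alpha, descrA.descr(),
//                                   csr_val, csr_row_ptr, csr_col_ind, x,
//                                   &beta, y));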

#if GOOGLE_CUDA

// A wrapper class to ensure that an unsorted/sorted CSR conversion information
// struct (csru2csrInfo_t) is initialized only once. See:
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
class GpuSparseCsrSortingConversionInfo {
 public:
  explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}

  GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
      : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
    rhs.initialized_ = false;
  }

  GpuSparseCsrSortingConversionInfo& operator=(
      GpuSparseCsrSortingConversionInfo&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    info_ = std::move(rhs.info_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseCsrSortingConversionInfo() { Release(); }

  // Initializes the underlying info. Will fail if called more than once.
  Status Initialize() {
    DCHECK(!initialized_);
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
    initialized_ = true;
    return Status::OK();
  }

  csru2csrInfo_t& info() {
    DCHECK(initialized_);
    return info_;
  }

  const csru2csrInfo_t& info() const {
    DCHECK(initialized_);
    return info_;
  }

 private:
  void Release() {
    if (initialized_) {
      cusparseDestroyCsru2csrInfo(info_);
      initialized_ = false;
    }
  }

  bool initialized_;
  csru2csrInfo_t info_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
};

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_