/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_

// This header declares the class GpuSparse, which contains wrappers of
// cuSparse libraries for use in TensorFlow kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cusparse.h"

using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;

#elif TENSORFLOW_USE_ROCM

#include "rocm/include/hipsparse/hipsparse.h"

using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;

#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/public/version.h"

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

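// Illustrative sketch (not part of the original header): an implementation
// file can hand TF_CALL_LAPACK_TYPES a two-argument macro to stamp out one
// explicit instantiation per supported scalar type. For a hypothetical
// INSTANTIATE_GTSV(Scalar, prefix) macro, TF_CALL_LAPACK_TYPES(INSTANTIATE_GTSV)
// expands to the equivalent of:
//
//   template Status GpuSparse::Gtsv<float>(
//       int, int, const float*, const float*, const float*, float*, int) const;
//   template Status GpuSparse::Gtsv<double>(
//       int, int, const double*, const double*, const double*, double*, int) const;
//   // ... and likewise for std::complex<float> and std::complex<double>.
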
namespace tensorflow {

inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
  switch (status) {
#define STRINGIZE(q) #q
#define RETURN_IF_STATUS(err) \
  case err:                   \
    return STRINGIZE(err);

#if GOOGLE_CUDA

    RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

    default:
      return strings::StrCat("Unknown CUSPARSE error: ",
                             static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM

    RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)

    default:
      return strings::StrCat("Unknown hipSPARSE error: ",
                             static_cast<int>(status));
#endif

#undef RETURN_IF_STATUS
#undef STRINGIZE
  }
}

#if GOOGLE_CUDA

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): cuSparse call failed with status ",      \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#elif TENSORFLOW_USE_ROCM

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) {            \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): hipSPARSE call failed with status ",     \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#endif

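// Illustrative sketch (not part of the original header) of how a wrapper
// implementation typically uses this macro; `BindStream`, `handle`, and
// `stream` are hypothetical names, and the CUDA spelling is shown:
//
//   Status BindStream(gpusparseHandle_t handle, gpuStream_t stream) {
//     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(handle, stream));
//     return Status::OK();
//   }
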
inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
                                                               bool conjugate,
                                                               Status* status) {
#if GOOGLE_CUDA
  if (transpose) {
    return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : CUSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return CUSPARSE_OPERATION_NON_TRANSPOSE;
  }
#elif TENSORFLOW_USE_ROCM
  if (transpose) {
    return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : HIPSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return HIPSPARSE_OPERATION_NON_TRANSPOSE;
  }
#endif
}

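// Illustrative usage sketch (not part of the original header); `transpose_a`
// and `conjugate_a` stand in for kernel attributes:
//
//   Status status;
//   const gpusparseOperation_t transA =
//       TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a, &status);
//   TF_RETURN_IF_ERROR(status);
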
// The GpuSparse class provides a simplified templated API for cuSparse
// (http://docs.nvidia.com/cuda/cusparse/index.html).
// An object of this class wraps static cuSparse instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
// object.
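//
// Example (illustrative sketch, not part of the original header; `context`
// is the OpKernelContext of the calling kernel):
//
//   GpuSparse cuda_sparse(context);
//   TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
//   TF_RETURN_IF_ERROR(
//       cuda_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));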

class GpuSparse {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSparse(OpKernelContext* context);
  virtual ~GpuSparse() {}

  // This initializes the GpuSparse class if it hasn't
  // been initialized yet.  All following public methods require the
  // class has been initialized.  Can be run multiple times; all
  // subsequent calls after the first have no effect.
  Status Initialize();  // Move to constructor?

  // ====================================================================
  // Wrappers for cuSparse start here.
  //

  // Solves tridiagonal system of equations.
  // Note: Cuda Toolkit 9.0+ has better-performing gtsv2 routine. gtsv will be
  // removed in Cuda Toolkit 11.0.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv
  // Returns Status::OK() if the kernel was launched successfully.
  template <typename Scalar>
  Status Gtsv(int m, int n, const Scalar *dl, const Scalar *d, const Scalar *du,
              Scalar *B, int ldb) const;

  // Solves tridiagonal system of equations without pivoting.
  // Note: Cuda Toolkit 9.0+ has better-performing gtsv2_nopivot routine.
  // gtsv_nopivot will be removed in Cuda Toolkit 11.0.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv_nopivot
  // Returns Status::OK() if the kernel was launched successfully.
  template <typename Scalar>
  Status GtsvNoPivot(int m, int n, const Scalar *dl, const Scalar *d,
                     const Scalar *du, Scalar *B, int ldb) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // Note: Cuda Toolkit 9.0+ has better-performing gtsv2StridedBatch routine.
  // gtsvStridedBatch will be removed in Cuda Toolkit 11.0.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsvstridedbatch
  // Returns Status::OK() if the kernel was launched successfully.
  template <typename Scalar>
  Status GtsvStridedBatch(int m, const Scalar *dl, const Scalar *d,
                          const Scalar *du, Scalar *x, int batchCount,
                          int batchStride) const;

  // Solves tridiagonal system of equations.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
  template <typename Scalar>
  Status Gtsv2(int m, int n, const Scalar *dl, const Scalar *d,
               const Scalar *du, Scalar *B, int ldb, void *pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
  template <typename Scalar>
  Status Gtsv2BufferSizeExt(int m, int n, const Scalar *dl, const Scalar *d,
                            const Scalar *du, const Scalar *B, int ldb,
                            size_t *bufferSizeInBytes) const;

  // Solves tridiagonal system of equations without partial pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
  template <typename Scalar>
  Status Gtsv2NoPivot(int m, int n, const Scalar *dl, const Scalar *d,
                      const Scalar *du, Scalar *B, int ldb,
                      void *pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2NoPivot.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
  template <typename Scalar>
  Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar *dl,
                                   const Scalar *d, const Scalar *du,
                                   const Scalar *B, int ldb,
                                   size_t *bufferSizeInBytes) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
  template <typename Scalar>
  Status Gtsv2StridedBatch(int m, const Scalar *dl, const Scalar *d,
                           const Scalar *du, Scalar *x, int batchCount,
                           int batchStride, void *pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2StridedBatch.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
  template <typename Scalar>
  Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar *dl,
                                        const Scalar *d, const Scalar *du,
                                        const Scalar *x, int batchCount,
                                        int batchStride,
                                        size_t *bufferSizeInBytes) const;

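  // Illustrative query-then-solve sketch (not part of the original header);
  // `cuda_sparse`, `context`, and the array names are placeholders:
  //
  //   size_t buffer_size;
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2StridedBatchBufferSizeExt(
  //       m, dl, d, du, x, batch_count, batch_stride, &buffer_size));
  //   Tensor temp;
  //   TF_RETURN_IF_ERROR(context->allocate_temp(
  //       DT_INT8, TensorShape({static_cast<int64>(buffer_size)}), &temp));
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2StridedBatch(
  //       m, dl, d, du, x, batch_count, batch_stride,
  //       temp.flat<int8>().data()));
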
  // Uncompresses the indices of rows or columns. It can be interpreted as a
  // conversion from CSR to COO sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
  Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const;

  // Compresses the indices of rows or columns. It can be interpreted as a
  // conversion from COO to CSR sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
  Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;

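  // Worked example (illustrative): for a matrix with m = 3 rows, nnz = 4
  // nonzeros, and per-row counts {2, 1, 1},
  //
  //   csrRowPtr = [0, 2, 3, 4]   <-- Csr2coo / Coo2csr -->   cooRowInd = [0, 0, 1, 2]
  //
  // Coo2csr assumes cooRowInd is sorted by row.
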
  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B)  + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense tall
  // matrices.  This routine allows transposition of matrix B, which
  // may improve performance.  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
  //
  // **NOTE** Matrices B and C are expected to be in column-major
  // order; to make them consistent with TensorFlow they
  // must be transposed (or the matmul op's pre/post-processing must take this
  // into account).
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
               int n, int k, int nnz, const Scalar* alpha_host,
               const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
               int ldc) const;

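  // Illustrative call sketch (not part of the original header) for the case
  // op(A) = A (m x k) and op(B) = B (k x n, column-major, ldb >= k), writing
  // into a column-major m x n C with ldc >= m; pointer names are placeholders
  // and the CUDA spelling of the operation enums is shown:
  //
  //   const Scalar alpha = 1, beta = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.Csrmm(
  //       CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
  //       m, n, k, nnz, &alpha, descrA.descr(), a_val, a_row_ptr, a_col_ind,
  //       b_ptr, /*ldb=*/k, &beta, c_ptr, /*ldc=*/m));
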
  // Sparse-dense vector multiplication y = alpha * op(A) * x  + beta * y,
  // where A is a sparse matrix in CSR format, x and y are dense vectors. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
  //
  // **NOTE** This is an in-place operation for data in y.
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
               const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
               const int* csrSortedColIndA, const Scalar* x,
               const Scalar* beta_host, Scalar* y) const;

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format.  This is part one: calculate nnz of the
  // output.  csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries.  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
                    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format.  This is part two: perform sparse-sparse
  // addition.  csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz).  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  template <typename Scalar>
  Status Csrgeam(int m, int n, const Scalar* alpha,
                 const gpusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const Scalar* beta,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);

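  // Illustrative two-phase sketch (not part of the original header); pointer
  // names are placeholders and alpha/beta are host scalars:
  //
  //   int nnz_c;
  //   TF_RETURN_IF_ERROR(cuda_sparse.CsrgeamNnz(
  //       m, n, descrA.descr(), nnzA, a_row_ptr, a_col_ind,
  //       descrB.descr(), nnzB, b_row_ptr, b_col_ind,
  //       descrC.descr(), c_row_ptr, &nnz_c));
  //   // ... allocate c_val and c_col_ind with nnz_c entries, then:
  //   TF_RETURN_IF_ERROR(cuda_sparse.Csrgeam(
  //       m, n, &alpha, descrA.descr(), nnzA, a_val, a_row_ptr, a_col_ind,
  //       &beta, descrB.descr(), nnzB, b_val, b_row_ptr, b_col_ind,
  //       descrC.descr(), c_val, c_row_ptr, c_col_ind));
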
  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format.  This is part one: calculate nnz of the
  // output.  csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries.  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
  Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
                    int m, int k, int n, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format.  This is part two: perform the sparse-sparse
  // multiplication.  csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz).  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
  template <typename Scalar>
  Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
                 int m, int k, int n, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);

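  // The same two-phase pattern as Csrgeam applies here (illustrative sketch,
  // placeholder names): CsrgemmNnz fills csrSortedRowPtrC and the output nnz,
  // the caller allocates csrSortedValC / csrSortedColIndC with that many
  // entries, and Csrgemm then writes the product, e.g.:
  //
  //   int nnz_c;
  //   TF_RETURN_IF_ERROR(cuda_sparse.CsrgemmNnz(
  //       transA, transB, m, k, n, descrA.descr(), nnzA, a_row_ptr, a_col_ind,
  //       descrB.descr(), nnzB, b_row_ptr, b_col_ind, descrC.descr(),
  //       c_row_ptr, &nnz_c));
  //   // ... allocate c_val / c_col_ind with nnz_c entries, then call Csrgemm.
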
  // In-place reordering of unsorted CSR to sorted CSR.
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
  template <typename Scalar>
  Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
                  Scalar* csrVal, const int* csrRowPtr, int* csrColInd);

  // Converts from CSR to CSC format (equivalently, transpose).
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
  template <typename Scalar>
  Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
                 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
                 int* cscRowInd, int* cscColPtr,
                 const gpusparseAction_t copyValues);

 private:
  bool initialized_;
  OpKernelContext *context_;  // not owned.
  gpuStream_t gpu_stream_;
  gpusparseHandle_t* gpusparse_handle_;  // not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
};

// A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
// only once. For more details on the descriptor (gpusparseMatDescr_t), see:
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
class GpuSparseMatrixDescriptor {
 public:
  explicit GpuSparseMatrixDescriptor() : initialized_(false) {}

  GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
      : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
    rhs.initialized_ = false;
  }

  GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    descr_ = std::move(rhs.descr_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseMatrixDescriptor() { Release(); }

  // Initializes the underlying descriptor.  Will fail on the second call if
  // called more than once.
  Status Initialize() {
    DCHECK(!initialized_);
#if GOOGLE_CUDA
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
#endif
    initialized_ = true;
    return Status::OK();
  }

  gpusparseMatDescr_t& descr() {
    DCHECK(initialized_);
    return descr_;
  }

  const gpusparseMatDescr_t& descr() const {
    DCHECK(initialized_);
    return descr_;
  }

 private:
  void Release() {
    if (initialized_) {
#if GOOGLE_CUDA
      cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
      hipsparseDestroyMatDescr(descr_);
#endif
      initialized_ = false;
    }
  }

  bool initialized_;
  gpusparseMatDescr_t descr_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
};

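// Illustrative usage sketch (not part of the original header):
//
//   GpuSparseMatrixDescriptor descrA;
//   TF_RETURN_IF_ERROR(descrA.Initialize());
//   // descrA.descr() can now be passed to the GpuSparse wrappers above.
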
#if GOOGLE_CUDA

// A wrapper class to ensure that an unsorted/sorted CSR conversion information
// struct (csru2csrInfo_t) is initialized only once. See:
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
class GpuSparseCsrSortingConversionInfo {
 public:
  explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}

  GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
      : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
    rhs.initialized_ = false;
  }

  GpuSparseCsrSortingConversionInfo& operator=(
      GpuSparseCsrSortingConversionInfo&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    info_ = std::move(rhs.info_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseCsrSortingConversionInfo() { Release(); }

  // Initializes the underlying info. Will fail on the second call if called
  // more than once.
  Status Initialize() {
    DCHECK(!initialized_);
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
    initialized_ = true;
    return Status::OK();
  }

  csru2csrInfo_t& info() {
    DCHECK(initialized_);
    return info_;
  }

  const csru2csrInfo_t& info() const {
    DCHECK(initialized_);
    return info_;
  }

 private:
  void Release() {
    if (initialized_) {
      cusparseDestroyCsru2csrInfo(info_);
      initialized_ = false;
    }
  }

  bool initialized_;
  csru2csrInfo_t info_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
};

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_