1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
17 #define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
18 
19 // This header declares the class GpuSparse, which contains wrappers of
20 // cuSparse libraries for use in TensorFlow kernels.
21 
22 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
23 
#include <functional>
#include <utility>
#include <vector>
26 
#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"

// Backend-neutral aliases: the rest of this header is written against the
// gpusparse* names, which resolve to the cuSPARSE types on CUDA builds.
using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;
#if CUDA_VERSION >= 10020
// Types used by the SpMM / SpMMBufferSize declarations, which are themselves
// only declared for CUDA_VERSION >= 10020.
using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
#endif

// Token-pasting helpers: GPUSPARSE(X) expands to CUSPARSE_X (enumerators and
// constants), gpusparse(X) expands to cusparseX (functions and types).
#define GPUSPARSE(postfix) CUSPARSE_##postfix
#define gpusparse(postfix) cusparse##postfix

#elif TENSORFLOW_USE_ROCM

#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"

// Same aliases as above, resolved to the hipSPARSE types on ROCm builds.
using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;
#if TF_ROCM_VERSION >= 40200
// Types used by the SpMM / SpMMBufferSize declarations, which are only
// declared for TF_ROCM_VERSION >= 40200 on ROCm.
using gpusparseDnMatDescr_t = hipsparseDnMatDescr_t;
using gpusparseSpMatDescr_t = hipsparseSpMatDescr_t;
using gpusparseSpMMAlg_t = hipsparseSpMMAlg_t;
#endif
// Token-pasting helpers: GPUSPARSE(X) expands to HIPSPARSE_X, gpusparse(X)
// expands to hipsparseX.
#define GPUSPARSE(postfix) HIPSPARSE_##postfix
#define gpusparse(postfix) hipsparse##postfix

#endif
66 
67 #include "tensorflow/core/framework/op_kernel.h"
68 #include "tensorflow/core/framework/tensor.h"
69 #include "tensorflow/core/framework/tensor_types.h"
70 #include "tensorflow/core/lib/core/status.h"
71 #include "tensorflow/core/platform/stream_executor.h"
72 #include "tensorflow/core/public/version.h"
73 
// Invokes m(Type, Prefix) once for each of the four standard numeric types,
// pairing each type with its BLAS/LAPACK-style letter prefix:
//   float -> S, double -> D, std::complex<float> -> C,
//   std::complex<double> -> Z.
// Typically used to stamp out one specialization of a sparse method per type.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S)                   \
  m(double, D)                  \
  m(std::complex<float>, C)     \
  m(std::complex<double>, Z)
79 
80 namespace tensorflow {
81 
ConvertGPUSparseErrorToString(const gpusparseStatus_t status)82 inline std::string ConvertGPUSparseErrorToString(
83     const gpusparseStatus_t status) {
84   switch (status) {
85 #define STRINGIZE(q) #q
86 #define RETURN_IF_STATUS(err) \
87   case err:                   \
88     return STRINGIZE(err);
89 
90 #if GOOGLE_CUDA
91 
92     RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
93     RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
94     RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
95     RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE)
96     RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH)
97     RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR)
98     RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED)
99     RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
100     RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
101 
102     default:
103       return strings::StrCat("Unknown CUSPARSE error: ",
104                              static_cast<int>(status));
105 #elif TENSORFLOW_USE_ROCM
106 
107     RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
108     RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
109     RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
110     RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
111     RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
112     RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
113     RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
114     RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
115     RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
116     RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)
117 
118     default:
119       return strings::StrCat("Unknown hipSPARSE error: ",
120                              static_cast<int>(status));
121 #endif
122 
123 #undef RETURN_IF_STATUS
124 #undef STRINGIZE
125   }
126 }
127 
#if GOOGLE_CUDA

// Evaluates `expr` (a cuSPARSE call) once; if it does not return
// CUSPARSE_STATUS_SUCCESS, returns an errors::Internal Status from the
// enclosing function that names the source location, the expression text and
// the status enumerator. The local is given a deliberately unusual name so it
// neither shadows a caller variable called `status` nor captures a `status`
// mentioned inside `expr`.
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                    \
  do {                                                                        \
    auto _tf_gpusparse_status = (expr);                                       \
    if (TF_PREDICT_FALSE(_tf_gpusparse_status != CUSPARSE_STATUS_SUCCESS)) {  \
      return errors::Internal(                                                \
          __FILE__, ":", __LINE__, " (", TF_STR(expr),                        \
          "): cuSparse call failed with status ",                             \
          ConvertGPUSparseErrorToString(_tf_gpusparse_status));               \
    }                                                                         \
  } while (0)

#elif TENSORFLOW_USE_ROCM

// ROCm counterpart of the macro above; checks against
// HIPSPARSE_STATUS_SUCCESS and reports a hipSPARSE failure.
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                    \
  do {                                                                        \
    auto _tf_gpusparse_status = (expr);                                       \
    if (TF_PREDICT_FALSE(_tf_gpusparse_status != HIPSPARSE_STATUS_SUCCESS)) { \
      return errors::Internal(                                                \
          __FILE__, ":", __LINE__, " (", TF_STR(expr),                        \
          "): hipSPARSE call failed with status ",                            \
          ConvertGPUSparseErrorToString(_tf_gpusparse_status));               \
    }                                                                         \
  } while (0)

#endif
153 
TransposeAndConjugateToGpuSparseOp(bool transpose,bool conjugate,Status * status)154 inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
155                                                                bool conjugate,
156                                                                Status* status) {
157 #if GOOGLE_CUDA
158   if (transpose) {
159     return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
160                      : CUSPARSE_OPERATION_TRANSPOSE;
161   } else {
162     if (conjugate) {
163       DCHECK(status != nullptr);
164       *status = errors::InvalidArgument(
165           "Conjugate == True and transpose == False is not supported.");
166     }
167     return CUSPARSE_OPERATION_NON_TRANSPOSE;
168   }
169 #elif TENSORFLOW_USE_ROCM
170   if (transpose) {
171     return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
172                      : HIPSPARSE_OPERATION_TRANSPOSE;
173   } else {
174     if (conjugate) {
175       DCHECK(status != nullptr);
176       *status = errors::InvalidArgument(
177           "Conjugate == True and transpose == False is not supported.");
178     }
179     return HIPSPARSE_OPERATION_NON_TRANSPOSE;
180   }
181 #endif
182 }
183 
184 // The GpuSparse class provides a simplified templated API for cuSparse
185 // (http://docs.nvidia.com/cuda/cusparse/index.html).
186 // An object of this class wraps static cuSparse instances,
187 // and will launch Cuda kernels on the stream wrapped by the GPU device
188 // in the OpKernelContext provided to the constructor.
189 //
190 // Notice: All the computational member functions are asynchronous and simply
191 // launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
192 // object.
193 
194 class GpuSparse {
195  public:
196   // This object stores a pointer to context, which must outlive it.
197   explicit GpuSparse(OpKernelContext* context);
~GpuSparse()198   virtual ~GpuSparse() {}
199 
200   // This initializes the GpuSparse class if it hasn't
201   // been initialized yet.  All following public methods require the
202   // class has been initialized.  Can be run multiple times; all
203   // subsequent calls after the first have no effect.
204   Status Initialize();  // Move to constructor?
205 
206   // ====================================================================
207   // Wrappers for cuSparse start here.
208   //
209 
210   // Solves tridiagonal system of equations.
211   // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
212   template <typename Scalar>
213   Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d,
214                const Scalar* du, Scalar* B, int ldb, void* pBuffer) const;
215 
216   // Computes the size of a temporary buffer used by Gtsv2.
217   // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
218   template <typename Scalar>
219   Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d,
220                             const Scalar* du, const Scalar* B, int ldb,
221                             size_t* bufferSizeInBytes) const;
222 
223   // Solves tridiagonal system of equations without partial pivoting.
224   // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
225   template <typename Scalar>
226   Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d,
227                       const Scalar* du, Scalar* B, int ldb,
228                       void* pBuffer) const;
229 
230   // Computes the size of a temporary buffer used by Gtsv2NoPivot.
231   // See:
232   // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
233   template <typename Scalar>
234   Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl,
235                                    const Scalar* d, const Scalar* du,
236                                    const Scalar* B, int ldb,
237                                    size_t* bufferSizeInBytes) const;
238 
239   // Solves a batch of tridiagonal systems of equations. Doesn't support
240   // multiple right-hand sides per each system. Doesn't do pivoting.
241   // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
242   template <typename Scalar>
243   Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d,
244                            const Scalar* du, Scalar* x, int batchCount,
245                            int batchStride, void* pBuffer) const;
246 
247   // Computes the size of a temporary buffer used by Gtsv2StridedBatch.
248   // See:
249   // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
250   template <typename Scalar>
251   Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl,
252                                         const Scalar* d, const Scalar* du,
253                                         const Scalar* x, int batchCount,
254                                         int batchStride,
255                                         size_t* bufferSizeInBytes) const;
256 
257   // Compresses the indices of rows or columns. It can be interpreted as a
258   // conversion from COO to CSR sparse storage format. See:
259   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
260   Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const;
261 
262   // Uncompresses the indices of rows or columns. It can be interpreted as a
263   // conversion from CSR to COO sparse storage format. See:
264   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
265   Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
266 
267 #if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || \
268     (TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 40200)
269   // Sparse-dense matrix multiplication C = alpha * op(A) * op(B)  + beta * C,
270   // where A is a sparse matrix in CSR format, B and C are dense tall
271   // matrices.  This routine allows transposition of matrix B, which
272   // may improve performance.  See:
273   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
274   //
275   // **NOTE** Matrices B and C are expected to be in column-major
276   // order; to make them consistent with TensorFlow they
277   // must be transposed (or the matmul op's pre/post-processing must take this
278   // into account).
279   //
280   // **NOTE** This is an in-place operation for data in C.
281   template <typename Scalar>
282   Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
283                int n, int k, int nnz, const Scalar* alpha_host,
284                const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
285                const int* csrSortedRowPtrA, const int* csrSortedColIndA,
286                const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
287                int ldc) const;
288 #else  // CUDA_VERSION >=10200 || TF_ROCM_VERSION >= 40200
289   // Workspace size query for sparse-dense matrix multiplication. Helper
290   // function for SpMM which computes y = alpha * op(A) * op(B) + beta * C,
291   // where A is a sparse matrix in CSR format, B and C are dense matricies in
292   // column-major format. Returns needed workspace size in bytes.
293   template <typename Scalar>
294   Status SpMMBufferSize(gpusparseOperation_t transA,
295                         gpusparseOperation_t transB, const Scalar* alpha,
296                         const gpusparseSpMatDescr_t matA,
297                         const gpusparseDnMatDescr_t matB, const Scalar* beta,
298                         gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
299                         size_t* bufferSize) const;
300 
301   // Sparse-dense matrix multiplication y = alpha * op(A) * op(B) + beta * C,
302   // where A is a sparse matrix in CSR format, B and C are dense matricies in
303   // column-major format. Buffer is assumed to be at least as large as the
304   // workspace size returned by SpMMBufferSize().
305   //
306   // **NOTE** This is an in-place operation for data in C.
307   template <typename Scalar>
308   Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
309               const Scalar* alpha, const gpusparseSpMatDescr_t matA,
310               const gpusparseDnMatDescr_t matB, const Scalar* beta,
311               gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
312               int8* buffer) const;
313 #endif
314 
315   // Sparse-dense vector multiplication y = alpha * op(A) * x  + beta * y,
316   // where A is a sparse matrix in CSR format, x and y are dense vectors. See:
317   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
318   //
319   // **NOTE** This is an in-place operation for data in y.
320 #if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
321   template <typename Scalar>
322   Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
323                const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
324                const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
325                const int* csrSortedColIndA, const Scalar* x,
326                const Scalar* beta_host, Scalar* y) const;
327 #else
328   template <typename Scalar>
329   Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
330                const Scalar* alpha_host, const Scalar* csrSortedValA,
331                const int* csrSortedRowPtrA, const int* csrSortedColIndA,
332                const Scalar* x, const Scalar* beta_host, Scalar* y) const;
333 #endif  // CUDA_VERSION < 10020
334 
335   // Computes workspace size for sparse - sparse matrix addition of matrices
336   // stored in CSR format.
337   template <typename Scalar>
338   Status CsrgeamBufferSizeExt(
339       int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
340       int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
341       const int* csrSortedColIndA, const Scalar* beta,
342       const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
343       const int* csrSortedRowPtrB, const int* csrSortedColIndB,
344       const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
345       int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);
346 
347   // Computes sparse-sparse matrix addition of matrices
348   // stored in CSR format.  This is part one: calculate nnz of the
349   // output.  csrSortedRowPtrC must be preallocated on device with
350   // m + 1 entries.  See:
351   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
352   Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
353                     const int* csrSortedRowPtrA, const int* csrSortedColIndA,
354                     const gpusparseMatDescr_t descrB, int nnzB,
355                     const int* csrSortedRowPtrB, const int* csrSortedColIndB,
356                     const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
357                     int* nnzTotalDevHostPtr, void* workspace);
358 
359   // Computes sparse - sparse matrix addition of matrices
360   // stored in CSR format.  This is part two: perform sparse-sparse
361   // addition.  csrValC and csrColIndC must be allocated on the device
362   // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz).  See:
363   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
364   template <typename Scalar>
365   Status Csrgeam(int m, int n, const Scalar* alpha,
366                  const gpusparseMatDescr_t descrA, int nnzA,
367                  const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
368                  const int* csrSortedColIndA, const Scalar* beta,
369                  const gpusparseMatDescr_t descrB, int nnzB,
370                  const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
371                  const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
372                  Scalar* csrSortedValC, int* csrSortedRowPtrC,
373                  int* csrSortedColIndC, void* workspace);
374 
375 #if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
376   // Computes sparse-sparse matrix multiplication of matrices
377   // stored in CSR format.  This is part zero: calculate required workspace
378   // size.
379   template <typename Scalar>
380   Status CsrgemmBufferSize(
381       int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
382       const int* csrSortedRowPtrA, const int* csrSortedColIndA,
383       const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
384       const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
385 #endif
386 
387   // Computes sparse-sparse matrix multiplication of matrices
388   // stored in CSR format.  This is part one: calculate nnz of the
389   // output.  csrSortedRowPtrC must be preallocated on device with
390   // m + 1 entries.  See:
391   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
392 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
393   Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
394                     int m, int k, int n, const gpusparseMatDescr_t descrA,
395                     int nnzA, const int* csrSortedRowPtrA,
396                     const int* csrSortedColIndA,
397                     const gpusparseMatDescr_t descrB, int nnzB,
398                     const int* csrSortedRowPtrB, const int* csrSortedColIndB,
399                     const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
400                     int* nnzTotalDevHostPtr);
401 #else
402   Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
403                     int nnzA, const int* csrSortedRowPtrA,
404                     const int* csrSortedColIndA,
405                     const gpusparseMatDescr_t descrB, int nnzB,
406                     const int* csrSortedRowPtrB, const int* csrSortedColIndB,
407                     const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
408                     int* nnzTotalDevHostPtr, csrgemm2Info_t info,
409                     void* workspace);
410 #endif
411 
412   // Computes sparse - sparse matrix matmul of matrices
413   // stored in CSR format.  This is part two: perform sparse-sparse
414   // addition.  csrValC and csrColIndC must be allocated on the device
415   // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz).  See:
416   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
417 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
418   template <typename Scalar>
419   Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
420                  int m, int k, int n, const gpusparseMatDescr_t descrA,
421                  int nnzA, const Scalar* csrSortedValA,
422                  const int* csrSortedRowPtrA, const int* csrSortedColIndA,
423                  const gpusparseMatDescr_t descrB, int nnzB,
424                  const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
425                  const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
426                  Scalar* csrSortedValC, int* csrSortedRowPtrC,
427                  int* csrSortedColIndC);
428 #else
429   template <typename Scalar>
430   Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
431                  int nnzA, const Scalar* csrSortedValA,
432                  const int* csrSortedRowPtrA, const int* csrSortedColIndA,
433                  const gpusparseMatDescr_t descrB, int nnzB,
434                  const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
435                  const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
436                  Scalar* csrSortedValC, int* csrSortedRowPtrC,
437                  int* csrSortedColIndC, const csrgemm2Info_t info,
438                  void* workspace);
439 #endif
440 
441   // In-place reordering of unsorted CSR to sorted CSR.
442   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
443   template <typename Scalar>
444   Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
445                   Scalar* csrVal, const int* csrRowPtr, int* csrColInd);
446 
447   // Converts from CSR to CSC format (equivalently, transpose).
448   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
449   template <typename Scalar>
450   Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
451                  const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
452                  int* cscRowInd, int* cscColPtr,
453                  const gpusparseAction_t copyValues);
454 
455  private:
456   bool initialized_;
457   OpKernelContext* context_;  // not owned.
458   gpuStream_t gpu_stream_;
459   gpusparseHandle_t* gpusparse_handle_;  // not owned.
460 
461   TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
462 };
463 
464 // A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
465 // only once. For more details on the descriptor (gpusparseMatDescr_t), see:
466 // https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
467 class GpuSparseMatrixDescriptor {
468  public:
GpuSparseMatrixDescriptor()469   explicit GpuSparseMatrixDescriptor() : initialized_(false) {}
470 
GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor && rhs)471   GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
472       : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
473     rhs.initialized_ = false;
474   }
475 
476   GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
477     if (this == &rhs) return *this;
478     Release();
479     initialized_ = rhs.initialized_;
480     descr_ = std::move(rhs.descr_);
481     rhs.initialized_ = false;
482     return *this;
483   }
484 
~GpuSparseMatrixDescriptor()485   ~GpuSparseMatrixDescriptor() { Release(); }
486 
487   // Initializes the underlying descriptor.  Will fail on the second call if
488   // called more than once.
Initialize()489   Status Initialize() {
490     DCHECK(!initialized_);
491 #if GOOGLE_CUDA
492     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
493 #elif TENSORFLOW_USE_ROCM
494     TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
495 #endif
496     initialized_ = true;
497     return Status::OK();
498   }
499 
descr()500   gpusparseMatDescr_t& descr() {
501     DCHECK(initialized_);
502     return descr_;
503   }
504 
descr()505   const gpusparseMatDescr_t& descr() const {
506     DCHECK(initialized_);
507     return descr_;
508   }
509 
510  private:
Release()511   void Release() {
512     if (initialized_) {
513 #if GOOGLE_CUDA
514       cusparseDestroyMatDescr(descr_);
515 #elif TENSORFLOW_USE_ROCM
516       wrap::hipsparseDestroyMatDescr(descr_);
517 #endif
518       initialized_ = false;
519     }
520   }
521 
522   bool initialized_;
523   gpusparseMatDescr_t descr_;
524 
525   TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
526 };
527 
528 #if GOOGLE_CUDA
529 
530 // A wrapper class to ensure that an unsorted/sorted CSR conversion information
531 // struct (csru2csrInfo_t) is initialized only once. See:
532 // https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
533 class GpuSparseCsrSortingConversionInfo {
534  public:
GpuSparseCsrSortingConversionInfo()535   explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}
536 
GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo && rhs)537   GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
538       : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
539     rhs.initialized_ = false;
540   }
541 
542   GpuSparseCsrSortingConversionInfo& operator=(
543       GpuSparseCsrSortingConversionInfo&& rhs) {
544     if (this == &rhs) return *this;
545     Release();
546     initialized_ = rhs.initialized_;
547     info_ = std::move(rhs.info_);
548     rhs.initialized_ = false;
549     return *this;
550   }
551 
~GpuSparseCsrSortingConversionInfo()552   ~GpuSparseCsrSortingConversionInfo() { Release(); }
553 
554   // Initializes the underlying info. Will fail on the second call if called
555   // more than once.
Initialize()556   Status Initialize() {
557     DCHECK(!initialized_);
558     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
559     initialized_ = true;
560     return Status::OK();
561   }
562 
info()563   csru2csrInfo_t& info() {
564     DCHECK(initialized_);
565     return info_;
566   }
567 
info()568   const csru2csrInfo_t& info() const {
569     DCHECK(initialized_);
570     return info_;
571   }
572 
573  private:
Release()574   void Release() {
575     if (initialized_) {
576       cusparseDestroyCsru2csrInfo(info_);
577       initialized_ = false;
578     }
579   }
580 
581   bool initialized_;
582   csru2csrInfo_t info_;
583 
584   TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
585 };
586 
587 #endif  // GOOGLE_CUDA
588 
589 }  // namespace tensorflow
590 
591 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
592 
593 #endif  // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
594