/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef GOOGLE_CUDA

#include "tensorflow/core/kernels/cuda_sparse.h"

#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "third_party/gpus/cuda/include/cusparse.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.
namespace tensorflow {
namespace {

// Type traits to get CUDA complex types from std::complex<>.
// TODO: reuse with cuda_solvers
template <typename T>
struct CudaComplexT {
  typedef T type;
};
template <>
struct CudaComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CudaComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CudaComplexT<T>::type* AsCudaComplex(const T* p) {
  return reinterpret_cast<const typename CudaComplexT<T>::type*>(p);
}
template <typename T>
inline typename CudaComplexT<T>::type* AsCudaComplex(T* p) {
  return reinterpret_cast<typename CudaComplexT<T>::type*>(p);
}
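// For example (illustrative only): given a `const std::complex<float>* p`,
// AsCudaComplex(p) returns a `const cuComplex*` suitable for passing to
// cuSparse, while pointers to real types such as `const float*` pass
// through unchanged.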

// A set of initialized handles to the underlying Cuda libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
class CudaSparseHandles {
 public:
  explicit CudaSparseHandles(cudaStream_t stream)
      : initialized_(false), stream_(stream) {}

  CudaSparseHandles(CudaSparseHandles&& rhs)
      : initialized_(rhs.initialized_),
        stream_(std::move(rhs.stream_)),
        cusparse_handle_(rhs.cusparse_handle_) {
    rhs.initialized_ = false;
  }

  CudaSparseHandles& operator=(CudaSparseHandles&& rhs) {
    if (this == &rhs) return *this;
    Release();
    stream_ = std::move(rhs.stream_);
    cusparse_handle_ = std::move(rhs.cusparse_handle_);
    initialized_ = rhs.initialized_;
    rhs.initialized_ = false;
    return *this;
  }

  ~CudaSparseHandles() { Release(); }

  Status Initialize() {
    if (initialized_) return Status::OK();
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
    initialized_ = true;
    return Status::OK();
  }

  cusparseHandle_t& handle() {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

  const cusparseHandle_t& handle() const {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

 private:
  void Release() {
    if (initialized_) {
      // This should never return anything other than success.
      auto err = cusparseDestroy(cusparse_handle_);
      DCHECK(err == CUSPARSE_STATUS_SUCCESS)
          << "Failed to destroy cuSparse instance.";
      initialized_ = false;
    }
  }
  bool initialized_;
  cudaStream_t stream_;
  cusparseHandle_t cusparse_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseHandles);
};
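
// A minimal usage sketch of CudaSparseHandles (mirroring what
// GpuSparse::Initialize() does below):
//
//   CudaSparseHandles handles(stream);
//   TF_RETURN_IF_ERROR(handles.Initialize());
//   cusparseHandle_t h = handles.handle();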

// TODO(ebrevdo): Replace the global mutex guarding CudaSparseHandles
// lookup with one of:
//    1. Adding the handle to the CudaStream structure; doing the lookup there.
//    2. Adding a thread-local cusparse handle, setting it to the current
//       stream upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap = std::unordered_map<cudaStream_t, CudaSparseHandles>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

GpuSparse::GpuSparse(OpKernelContext* context)
    : initialized_(false), context_(context) {
  auto cuda_stream_ptr =
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack());
  DCHECK(cuda_stream_ptr);
  gpu_stream_ = *cuda_stream_ptr;
}

Status GpuSparse::Initialize() {
  HandleMap* handle_map = GetHandleMapSingleton();
  DCHECK(handle_map);
  mutex_lock lock(handle_map_mutex);
  auto it = handle_map->find(gpu_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda sparse library
    // handles for it.
    CudaSparseHandles new_handles(gpu_stream_);
    TF_RETURN_IF_ERROR(new_handles.Initialize());
    it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
             .first;
  }
  gpusparse_handle_ = &it->second.handle();
  initialized_ = true;
  return Status::OK();
}

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
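// For a hypothetical macro FOO, TF_CALL_LAPACK_TYPES(FOO) expands to
//   FOO(float, S) FOO(double, D)
//   FOO(std::complex<float>, C) FOO(std::complex<double>, Z)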

// Macros to construct cusparse method names.
#define SPARSE_FN(method, sparse_prefix) cusparse##sparse_prefix##method
#define SPARSE_NAME(method, sparse_prefix) "cusparse" #sparse_prefix #method
#define BUFSIZE_FN(method, sparse_prefix) \
  cusparse##sparse_prefix##method##_bufferSizeExt
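// For example, SPARSE_FN(gtsv, S) names cusparseSgtsv,
// SPARSE_NAME(gtsv, S) yields the string "cusparseSgtsv", and
// BUFSIZE_FN(csru2csr, S) names cusparseScsru2csr_bufferSizeExt.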

//=============================================================================
// Wrappers of cuSparse computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., they may omit 'const' on pointers to
// immutable arguments that the actual headers declare as expected.
// Check the actual declarations in the cusparse.h header file.
//=============================================================================

template <typename Scalar, typename SparseFn>
static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
                              int m, int n, const Scalar* dl, const Scalar* d,
                              const Scalar* du, Scalar* B, int ldb) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb));
  return Status::OK();
}

#define GTSV_INSTANCE(Scalar, sparse_prefix)                                   \
  template <>                                                                  \
  Status GpuSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl,               \
                                 const Scalar* d, const Scalar* du, Scalar* B, \
                                 int ldb) const {                              \
    DCHECK(initialized_);                                                      \
    return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n,  \
                    dl, d, du, B, ldb);                                        \
  }

TF_CALL_LAPACK_TYPES(GTSV_INSTANCE);
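// Note: the instantiations above define Gtsv<float>, Gtsv<double>,
// Gtsv<std::complex<float>>, and Gtsv<std::complex<double>>, dispatching to
// cusparseSgtsv, cusparseDgtsv, cusparseCgtsv, and cusparseZgtsv.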

#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix)                      \
  template <>                                                              \
  Status GpuSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl,    \
                                        const Scalar* d, const Scalar* du, \
                                        Scalar* B, int ldb) const {        \
    DCHECK(initialized_);                                                  \
    return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix),                \
                    *gpusparse_handle_, m, n, dl, d, du, B, ldb);          \
  }

TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status GtsvStridedBatchImpl(SparseFn op,
                                          cusparseHandle_t cusparse_handle,
                                          int m, const Scalar* dl,
                                          const Scalar* d, const Scalar* du,
                                          Scalar* x, int batchCount,
                                          int batchStride) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(x), batchCount, batchStride));
  return Status::OK();
}

#define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix)                   \
  template <>                                                                \
  Status GpuSparse::GtsvStridedBatch<Scalar>(                                \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
      int batchCount, int batchStride) const {                               \
    DCHECK(initialized_);                                                    \
    return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix),  \
                                *gpusparse_handle_, m, dl, d, du, x,         \
                                batchCount, batchStride);                    \
  }

TF_CALL_LAPACK_TYPES(GTSV_STRIDED_BATCH_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
                               int m, int n, const Scalar* dl, const Scalar* d,
                               const Scalar* du, Scalar* B, int ldb,
                               void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, pBuffer));
  return Status::OK();
}

#define GTSV2_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl,            \
                                  const Scalar* d, const Scalar* du,         \
                                  Scalar* B, int ldb, void* pBuffer) const { \
    DCHECK(initialized_);                                                    \
    return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \
                     n, dl, d, du, B, ldb, pBuffer);                         \
  }

TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE);

#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix)                      \
  template <>                                                               \
  Status GpuSparse::Gtsv2NoPivot<Scalar>(                                   \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du,    \
      Scalar* B, int ldb, void* pBuffer) const {                            \
    DCHECK(initialized_);                                                   \
    return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix),               \
                     *gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
                                            cusparseHandle_t cusparse_handle,
                                            int m, int n, const Scalar* dl,
                                            const Scalar* d, const Scalar* du,
                                            const Scalar* B, int ldb,
                                            size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, bufferSizeInBytes));
  return Status::OK();
}

#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix)                     \
  template <>                                                                 \
  Status GpuSparse::Gtsv2BufferSizeExt<Scalar>(                               \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du,      \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const {            \
    DCHECK(initialized_);                                                     \
    return Gtsv2BufferSizeExtImpl(                                            \
        SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \
        n, dl, d, du, B, ldb, bufferSizeInBytes);                             \
  }

TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE);

#define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix)       \
  template <>                                                            \
  Status GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>(                   \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const {       \
    DCHECK(initialized_);                                                \
    return Gtsv2BufferSizeExtImpl(                                       \
        SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix),           \
        *gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchImpl(SparseFn op,
                                           cusparseHandle_t cusparse_handle,
                                           int m, const Scalar* dl,
                                           const Scalar* d, const Scalar* du,
                                           Scalar* x, int batchCount,
                                           int batchStride, void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d),
      AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer));
  return Status::OK();
}

#define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix)                   \
  template <>                                                                 \
  Status GpuSparse::Gtsv2StridedBatch<Scalar>(                                \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x,  \
      int batchCount, int batchStride, void* pBuffer) const {                 \
    DCHECK(initialized_);                                                     \
    return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \
                                 *gpusparse_handle_, m, dl, d, du, x,         \
                                 batchCount, batchStride, pBuffer);           \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchBufferSizeImpl(
    SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl,
    const Scalar* d, const Scalar* du, const Scalar* x, int batchCount,
    int batchStride, size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(x), batchCount, batchStride,
                                  bufferSizeInBytes));
  return Status::OK();
}

#define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <>                                                           \
  Status GpuSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>(             \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du,       \
      const Scalar* x, int batchCount, int batchStride,                 \
      size_t* bufferSizeInBytes) const {                                \
    DCHECK(initialized_);                                               \
    return Gtsv2StridedBatchBufferSizeImpl(                             \
        SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix),      \
        *gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride,   \
        bufferSizeInBytes);                                             \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);

Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                          int* csrRowPtr) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
  //                                               const int *cooRowInd,
  //                                               int nnz,
  //                                               int m,
  //                                               int *csrSortedRowPtr,
  //                                               cusparseIndexBase_t idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd,
                                                nnz, m, csrRowPtr,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                          int* cooRowInd) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
  //                                               const int *csrRowPtr,
  //                                               int nnz,
  //                                               int m,
  //                                               int *cooRowInd,
  //                                               cusparseIndexBase_t idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
                                                nnz, m, cooRowInd,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
                             int nnzA, const int* csrSortedRowPtrA,
                             const int* csrSortedColIndA,
                             const cusparseMatDescr_t descrB, int nnzB,
                             const int* csrSortedRowPtrB,
                             const int* csrSortedColIndB,
                             const cusparseMatDescr_t descrC,
                             int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, int k,
    int nnz, const Scalar* alpha_host, const cusparseMatDescr_t descrA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* B, int ldb,
    const Scalar* beta_host, Scalar* C, int ldc) {
  // cusparseStatus_t CUSPARSEAPI cusparseScsrmm2(
  //     cusparseHandle_t handle, cusparseOperation_t transA,
  //     cusparseOperation_t transB, int m, int n, int k, int nnz,
  //     const float* alpha, const cusparseMatDescr_t descrA,
  //     const float* csrSortedValA, const int* csrSortedRowPtrA,
  //     const int* csrSortedColIndA, const float* B, int ldb,
  //     const float* beta, float* C, int ldc);
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
      descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
  return Status::OK();
}

#define CSRMM_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Csrmm<Scalar>(                                           \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int n,  \
      int k, int nnz, const Scalar* alpha_host,                              \
      const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,          \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,              \
      const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
      const {                                                                \
    DCHECK(initialized_);                                                    \
    return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_,             \
                     *gpusparse_handle_, transA, transB, m, n, k, nnz,       \
                     alpha_host, descrA, csrSortedValA, csrSortedRowPtrA,    \
                     csrSortedColIndA, B, ldb, beta_host, C, ldc);           \
  }

TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
    const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
    const Scalar* beta_host, Scalar* y) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
  return Status::OK();
}

// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Csrmv<Scalar>(                                           \
      cusparseOperation_t transA, int m, int n, int nnz,                     \
      const Scalar* alpha_host, const cusparseMatDescr_t descrA,             \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,              \
      const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
      Scalar* y) const {                                                     \
    DCHECK(initialized_);                                                    \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) {                        \
      return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_,         \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host,    \
                       descrA, csrSortedValA, csrSortedRowPtrA,              \
                       csrSortedColIndA, x, beta_host, y);                   \
    } else {                                                                 \
      return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_,            \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host,    \
                       descrA, csrSortedValA, csrSortedRowPtrA,              \
                       csrSortedColIndA, x, beta_host, y);                   \
    }                                                                        \
  }

TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
         csrSortedRowPtrB, csrSortedColIndB, descrC,
         AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgeam<Scalar>(                                          \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
      const int* csrSortedColIndA, const Scalar* beta,                        \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
      int* csrSortedRowPtrC, int* csrSortedColIndC) {                         \
    DCHECK(initialized_);                                                     \
    return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_,           \
                       *gpusparse_handle_, m, n, alpha, descrA, nnzA,         \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,     \
                       beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB,   \
                       csrSortedColIndB, descrC, csrSortedValC,               \
                       csrSortedRowPtrC, csrSortedColIndC);                   \
  }

TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);

Status GpuSparse::CsrgemmNnz(
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
    int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz(
      *gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
      csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
      csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
         csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgemm<Scalar>(                                          \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int k,   \
      int n, const cusparseMatDescr_t descrA, int nnzA,                       \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB,               \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC,           \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {  \
    DCHECK(initialized_);                                                     \
    return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_,           \
                       *gpusparse_handle_, transA, transB, m, k, n, descrA,   \
                       nnzA, csrSortedValA, csrSortedRowPtrA,                 \
                       csrSortedColIndA, descrB, nnzB, csrSortedValB,         \
                       csrSortedRowPtrB, csrSortedColIndB, descrC,            \
                       csrSortedValC, csrSortedRowPtrC, csrSortedColIndC);    \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
                                  OpKernelContext* context,
                                  cusparseHandle_t cusparse_handle, int m,
                                  int n, int nnz,
                                  const cusparseMatDescr_t descrA,
                                  Scalar* csrVal, const int* csrRowPtr,
                                  int* csrColInd) {
  GpuSparseCsrSortingConversionInfo info;
  TF_RETURN_IF_ERROR(info.Initialize());

  size_t pBufferSizeInBytes = 0;

  TF_RETURN_IF_GPUSPARSE_ERROR(
      buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                     csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));

  Tensor pBuffer_t;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(pBufferSizeInBytes)}),
      &pBuffer_t));
  auto pBuffer = pBuffer_t.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  info.info(), pBuffer.data()));

  return Status::OK();
}

#define CSRU2CSR_INSTANCE(Scalar, sparse_prefix)                              \
  template <>                                                                 \
  Status GpuSparse::Csru2csr<Scalar>(                                         \
      int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
      const int* csrRowPtr, int* csrColInd) {                                 \
    DCHECK(initialized_);                                                     \
    return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix),                   \
                        BUFSIZE_FN(csru2csr, sparse_prefix), context_,        \
                        *gpusparse_handle_, m, n, nnz, descrA, csrVal,        \
                        csrRowPtr, csrColInd);                                \
  }

TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  AsCudaComplex(cscVal), cscRowInd, cscColPtr,
                                  copyValues, CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

#define CSR2CSC_INSTANCE(Scalar, sparse_prefix)                              \
  template <>                                                                \
  Status GpuSparse::Csr2csc<Scalar>(                                         \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr,     \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr,  \
      const cusparseAction_t copyValues) {                                   \
    DCHECK(initialized_);                                                    \
    return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_,          \
                       *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr,     \
                       csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
  }

TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
690