/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef GOOGLE_CUDA

#include "tensorflow/core/kernels/cuda_sparse.h"

#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "third_party/gpus/cuda/include/cusparse.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.

namespace tensorflow {
namespace {

// Type traits to get CUDA complex types from std::complex<>.
// TODO: reuse with cuda_solvers
template <typename T>
struct CudaComplexT {
  typedef T type;
};
template <>
struct CudaComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CudaComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CudaComplexT<T>::type* AsCudaComplex(const T* p) {
  return reinterpret_cast<const typename CudaComplexT<T>::type*>(p);
}
template <typename T>
inline typename CudaComplexT<T>::type* AsCudaComplex(T* p) {
  return reinterpret_cast<typename CudaComplexT<T>::type*>(p);
}
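// For example, AsCudaComplex() maps a const std::complex<float>* to a
// const cuComplex*, while a plain float* or double* passes through with its
// type unchanged.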

// A set of initialized handles to the underlying Cuda libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
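// Call Initialize() before using handle(); the cusparse handle is destroyed
// automatically when the (non-moved-from) object goes out of scope.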
class CudaSparseHandles {
 public:
  explicit CudaSparseHandles(cudaStream_t stream)
      : initialized_(false), stream_(stream) {}

  CudaSparseHandles(CudaSparseHandles&& rhs)
      : initialized_(rhs.initialized_),
        stream_(std::move(rhs.stream_)),
        cusparse_handle_(rhs.cusparse_handle_) {
    rhs.initialized_ = false;
  }

  CudaSparseHandles& operator=(CudaSparseHandles&& rhs) {
    if (this == &rhs) return *this;
    Release();
    stream_ = std::move(rhs.stream_);
    cusparse_handle_ = std::move(rhs.cusparse_handle_);
    initialized_ = rhs.initialized_;
    rhs.initialized_ = false;
    return *this;
  }

  ~CudaSparseHandles() { Release(); }

  Status Initialize() {
    if (initialized_) return Status::OK();
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
    initialized_ = true;
    return Status::OK();
  }

  cusparseHandle_t& handle() {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

  const cusparseHandle_t& handle() const {
    DCHECK(initialized_);
    return cusparse_handle_;
  }

 private:
  void Release() {
    if (initialized_) {
      // This should never return anything other than success.
      auto err = cusparseDestroy(cusparse_handle_);
      DCHECK(err == CUSPARSE_STATUS_SUCCESS)
          << "Failed to destroy cuSparse instance.";
      initialized_ = false;
    }
  }
  bool initialized_;
  cudaStream_t stream_;
  cusparseHandle_t cusparse_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseHandles);
};

// TODO(ebrevdo): Replace the global mutex guarding CudaSparseHandles
// lookup with one of:
//   1. Adding the handle to the CudaStream structure; do the lookup there.
//   2. Adding a thread-local cusparse handle and setting it to the current
//      stream upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap = std::unordered_map<cudaStream_t, CudaSparseHandles>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
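// The map (and the handles it owns) is never deleted; it lives for the
// lifetime of the process.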
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

GpuSparse::GpuSparse(OpKernelContext* context)
    : initialized_(false), context_(context) {
  auto cuda_stream_ptr =
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack());
  DCHECK(cuda_stream_ptr);
  gpu_stream_ = *cuda_stream_ptr;
}

Status GpuSparse::Initialize() {
  HandleMap* handle_map = GetHandleMapSingleton();
  DCHECK(handle_map);
  mutex_lock lock(handle_map_mutex);
  auto it = handle_map->find(gpu_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda sparse library
    // handles for it.
    CudaSparseHandles new_handles(gpu_stream_);
    TF_RETURN_IF_ERROR(new_handles.Initialize());
    it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
             .first;
  }
  gpusparse_handle_ = &it->second.handle();
  initialized_ = true;
  return Status::OK();
}

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
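// For example, TF_CALL_LAPACK_TYPES(GTSV_INSTANCE) instantiates the Gtsv<>
// specializations below for float, double, std::complex<float>, and
// std::complex<double>.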

// Macros to construct cusparse method names.
#define SPARSE_FN(method, sparse_prefix) cusparse##sparse_prefix##method
#define SPARSE_NAME(method, sparse_prefix) "cusparse" #sparse_prefix #method
#define BUFSIZE_FN(method, sparse_prefix) \
  cusparse##sparse_prefix##method##_bufferSizeExt
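// For example, SPARSE_FN(gtsv, S) expands to cusparseSgtsv, and
// BUFSIZE_FN(csru2csr, D) expands to cusparseDcsru2csr_bufferSizeExt.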

//=============================================================================
// Wrappers of cuSparse computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusparse.h header file.
//=============================================================================

template <typename Scalar, typename SparseFn>
static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
                              int m, int n, const Scalar* dl, const Scalar* d,
                              const Scalar* du, Scalar* B, int ldb) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb));
  return Status::OK();
}

#define GTSV_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
                                 const Scalar* d, const Scalar* du, Scalar* B, \
                                 int ldb) const { \
    DCHECK(initialized_); \
    return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \
                    dl, d, du, B, ldb); \
  }

TF_CALL_LAPACK_TYPES(GTSV_INSTANCE);

#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
                                        const Scalar* d, const Scalar* du, \
                                        Scalar* B, int ldb) const { \
    DCHECK(initialized_); \
    return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \
                    *gpusparse_handle_, m, n, dl, d, du, B, ldb); \
  }

TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status GtsvStridedBatchImpl(SparseFn op,
                                          cusparseHandle_t cusparse_handle,
                                          int m, const Scalar* dl,
                                          const Scalar* d, const Scalar* du,
                                          Scalar* x, int batchCount,
                                          int batchStride) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(x), batchCount, batchStride));
  return Status::OK();
}

#define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::GtsvStridedBatch<Scalar>( \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
      int batchCount, int batchStride) const { \
    DCHECK(initialized_); \
    return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \
                                *gpusparse_handle_, m, dl, d, du, x, \
                                batchCount, batchStride); \
  }

TF_CALL_LAPACK_TYPES(GTSV_STRIDED_BATCH_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
                               int m, int n, const Scalar* dl, const Scalar* d,
                               const Scalar* du, Scalar* B, int ldb,
                               void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, pBuffer));
  return Status::OK();
}

#define GTSV2_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl, \
                                  const Scalar* d, const Scalar* du, \
                                  Scalar* B, int ldb, void* pBuffer) const { \
    DCHECK(initialized_); \
    return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \
                     n, dl, d, du, B, ldb, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE);

#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2NoPivot<Scalar>( \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      Scalar* B, int ldb, void* pBuffer) const { \
    DCHECK(initialized_); \
    return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \
                     *gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
                                            cusparseHandle_t cusparse_handle,
                                            int m, int n, const Scalar* dl,
                                            const Scalar* d, const Scalar* du,
                                            const Scalar* B, int ldb,
                                            size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(B), ldb, bufferSizeInBytes));
  return Status::OK();
}

#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2BufferSizeExt<Scalar>( \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
    DCHECK(initialized_); \
    return Gtsv2BufferSizeExtImpl( \
        SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \
        n, dl, d, du, B, ldb, bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE);

#define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>( \
      int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
    DCHECK(initialized_); \
    return Gtsv2BufferSizeExtImpl( \
        SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix), \
        *gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE);
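
// A typical call sequence for the gtsv2 wrappers above (a sketch; `cusparse`
// is a GpuSparse instance whose Initialize() has already succeeded, and dl,
// d, du, B are device pointers):
//
//   size_t buffer_size;
//   TF_RETURN_IF_ERROR(
//       cusparse.Gtsv2BufferSizeExt(m, n, dl, d, du, B, ldb, &buffer_size));
//   // ... allocate a scratch buffer of at least buffer_size bytes on the
//   // device, e.g. via OpKernelContext::allocate_temp ...
//   TF_RETURN_IF_ERROR(cusparse.Gtsv2(m, n, dl, d, du, B, ldb, buffer));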

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchImpl(SparseFn op,
                                           cusparseHandle_t cusparse_handle,
                                           int m, const Scalar* dl,
                                           const Scalar* d, const Scalar* du,
                                           Scalar* x, int batchCount,
                                           int batchStride, void* pBuffer) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d),
      AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer));
  return Status::OK();
}

#define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2StridedBatch<Scalar>( \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
      int batchCount, int batchStride, void* pBuffer) const { \
    DCHECK(initialized_); \
    return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \
                                 *gpusparse_handle_, m, dl, d, du, x, \
                                 batchCount, batchStride, pBuffer); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_INSTANCE);

template <typename Scalar, typename SparseFn>
static inline Status Gtsv2StridedBatchBufferSizeImpl(
    SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl,
    const Scalar* d, const Scalar* du, const Scalar* x, int batchCount,
    int batchStride, size_t* bufferSizeInBytes) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
                                  AsCudaComplex(d), AsCudaComplex(du),
                                  AsCudaComplex(x), batchCount, batchStride,
                                  bufferSizeInBytes));
  return Status::OK();
}

#define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>( \
      int m, const Scalar* dl, const Scalar* d, const Scalar* du, \
      const Scalar* x, int batchCount, int batchStride, \
      size_t* bufferSizeInBytes) const { \
    DCHECK(initialized_); \
    return Gtsv2StridedBatchBufferSizeImpl( \
        SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix), \
        *gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \
        bufferSizeInBytes); \
  }

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);

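// Converts COO row indices (cooRowInd, length nnz) into a compressed CSR row
// pointer array (csrRowPtr, length m + 1), assuming zero-based indexing.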
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                          int* csrRowPtr) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
  //                                               const int *cooRowInd,
  //                                               int nnz,
  //                                               int m,
  //                                               int *csrSortedRowPtr,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd,
                                                nnz, m, csrRowPtr,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

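// The inverse of Coo2csr: expands a CSR row pointer array (length m + 1) back
// into COO row indices (length nnz), assuming zero-based indexing.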
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                          int* cooRowInd) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
  //                                               const int *csrRowPtr,
  //                                               int nnz,
  //                                               int m,
  //                                               int *cooRowInd,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
                                                nnz, m, cooRowInd,
                                                CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

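// Computes the CSR row pointer array and total number of nonzeros of C for
// the sparse matrix sum performed by Csrgeam (C = alpha * A + beta * B); call
// it first to size C's value and column index arrays.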
Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
                             int nnzA, const int* csrSortedRowPtrA,
                             const int* csrSortedColIndA,
                             const cusparseMatDescr_t descrB, int nnzB,
                             const int* csrSortedRowPtrB,
                             const int* csrSortedColIndB,
                             const cusparseMatDescr_t descrC,
                             int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, int k,
    int nnz, const Scalar* alpha_host, const cusparseMatDescr_t descrA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* B, int ldb,
    const Scalar* beta_host, Scalar* C, int ldc) {
  // cusparseStatus_t CUSPARSEAPI cusparseScsrmm2(
  //     cusparseHandle_t handle, cusparseOperation_t transA,
  //     cusparseOperation_t transB, int m, int n, int k, int nnz,
  //     const float* alpha, const cusparseMatDescr_t descrA,
  //     const float* csrSortedValA, const int* csrSortedRowPtrA,
  //     const int* csrSortedColIndA, const float* B, int ldb, const float*
  //     beta, float* C, int ldc);
  TF_RETURN_IF_GPUSPARSE_ERROR(op(
      cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
      descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
  return Status::OK();
}

#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrmm<Scalar>( \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, \
      int k, int nnz, const Scalar* alpha_host, \
      const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
      const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
      const { \
    DCHECK(initialized_); \
    return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
                     *gpusparse_handle_, transA, transB, m, n, k, nnz, \
                     alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
                     csrSortedColIndA, B, ldb, beta_host, C, ldc); \
  }

TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
    const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
    const Scalar* beta_host, Scalar* y) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
  return Status::OK();
}

// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrmv<Scalar>( \
      cusparseOperation_t transA, int m, int n, int nnz, \
      const Scalar* alpha_host, const cusparseMatDescr_t descrA, \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
      Scalar* y) const { \
    DCHECK(initialized_); \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
      return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_, \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host, \
                       descrA, csrSortedValA, csrSortedRowPtrA, \
                       csrSortedColIndA, x, beta_host, y); \
    } else { \
      return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
                       *gpusparse_handle_, transA, m, n, nnz, alpha_host, \
                       descrA, csrSortedValA, csrSortedRowPtrA, \
                       csrSortedColIndA, x, beta_host, y); \
    } \
  }

TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
         csrSortedRowPtrB, csrSortedColIndB, descrC,
         AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrgeam<Scalar>( \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const Scalar* beta, \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
      int* csrSortedRowPtrC, int* csrSortedColIndC) { \
    DCHECK(initialized_); \
    return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \
                       *gpusparse_handle_, m, n, alpha, descrA, nnzA, \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
                       beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
                       csrSortedColIndB, descrC, csrSortedValC, \
                       csrSortedRowPtrC, csrSortedColIndC); \
  }

TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);

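// Computes the CSR row pointer array and total number of nonzeros of C for
// the sparse matrix product performed by Csrgemm (C = op(A) * op(B)); call it
// first to size C's value and column index arrays.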
Status GpuSparse::CsrgemmNnz(
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
    int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz(
      *gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
      csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
      csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
         csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csrgemm<Scalar>( \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, \
      int n, const cusparseMatDescr_t descrA, int nnzA, \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC, \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
    DCHECK(initialized_); \
    return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
                       *gpusparse_handle_, transA, transB, m, k, n, descrA, \
                       nnzA, csrSortedValA, csrSortedRowPtrA, \
                       csrSortedColIndA, descrB, nnzB, csrSortedValB, \
                       csrSortedRowPtrB, csrSortedColIndB, descrC, \
                       csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

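// Sorts the column indices (and permutes the values accordingly) of an
// unsorted CSR matrix in place, allocating the temporary workspace that
// cusparse requires from the op kernel context.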
template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
                                  OpKernelContext* context,
                                  cusparseHandle_t cusparse_handle, int m,
                                  int n, int nnz,
                                  const cusparseMatDescr_t descrA,
                                  Scalar* csrVal, const int* csrRowPtr,
                                  int* csrColInd) {
  GpuSparseCsrSortingConversionInfo info;
  TF_RETURN_IF_ERROR(info.Initialize());

  size_t pBufferSizeInBytes = 0;

  TF_RETURN_IF_GPUSPARSE_ERROR(
      buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                     csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));

  Tensor pBuffer_t;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(pBufferSizeInBytes)}),
      &pBuffer_t));
  auto pBuffer = pBuffer_t.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  info.info(), pBuffer.data()));

  return Status::OK();
}

#define CSRU2CSR_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csru2csr<Scalar>( \
      int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
      const int* csrRowPtr, int* csrColInd) { \
    DCHECK(initialized_); \
    return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix), \
                        BUFSIZE_FN(csru2csr, sparse_prefix), context_, \
                        *gpusparse_handle_, m, n, nnz, descrA, csrVal, \
                        csrRowPtr, csrColInd); \
  }

TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);

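// Converts a CSR matrix to CSC format (equivalently, transposes the CSR
// representation); copyValues controls whether the numeric values are copied
// or only the sparsity pattern is converted.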
template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
                                  AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                  AsCudaComplex(cscVal), cscRowInd, cscColPtr,
                                  copyValues, CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
  template <> \
  Status GpuSparse::Csr2csc<Scalar>( \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
      const cusparseAction_t copyValues) { \
    DCHECK(initialized_); \
    return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
                       *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
                       csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
  }

TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA