/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

// TODO(anudhyan): These constants may be tuned based on the performance of
// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
// which work across hardware platforms for typical matrix sizes. It should be
// possible to observe at least 30-50% improvement as we increase the number
// of threads by 1. If not, then it may be worth increasing kMaxShards and
// kNumShardsPerThread. However, once we have too many shards, latency may be
// dominated by per-shard overhead.
//
// Maximum number of shards into which to divide the computation for each CSR
// Sparse Matrix instance.
static constexpr int32_t kMaxShards = 20;
// Number of shards allocated to each thread.
static constexpr int32_t kNumShardsPerThread = 3;
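
// Illustrative sharding arithmetic (hypothetical numbers, not part of the
// kernel logic): the CPU kernels below choose a fixed block size of roughly
//   block_size = num_rows / max(kMaxShards, kNumShardsPerThread * num_threads)
// so with, say, 4 intra-op threads and a 10,000-row matrix this yields
// max(20, 3 * 4) = 20 shards of about 500 rows each, showing how the two
// constants interact.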

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Abstract OpKernel to compute sparse-dense matrix multiplication.
//
// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
// multiplication.
//
// The boolean attributes `transpose_a` and `adjoint_a` will transpose or
// adjoint `a` before multiplication, respectively. At most one of these
// attributes may be set to True. The corresponding attributes for `b` will
// transpose or adjoint `b`, and `transpose_output`/`conjugate_output` act on
// the output (after multiplication).
//
// The ranks of `a` and `b` must be equal and their shapes must be compatible
// for matrix multiplication. Otherwise, InvalidArgument errors are returned.
// Only rank 2 or rank 3 inputs are supported.
//
template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ |= adjoint_a;
    transpose_b_ |= adjoint_b;
  }

  ~CSRMatMulOp() override {}

  Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
                        const Tensor& dense_tensor_b, int* rank,
                        int64* batch_size) {
    if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
      return errors::InvalidArgument(
          "Input types don't match.  a.dtype == ",
          DataTypeString(sparse_matrix_a.dtype()),
          " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
    }
    *rank = sparse_matrix_a.dims();
    // TODO(ebrevdo): Add support for broadcasting matmul.
    if (*rank != dense_tensor_b.dims()) {
      return errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                     *rank, " vs. ", dense_tensor_b.dims(),
                                     ".");
    }
    // A valid CSR SparseMatrix has rank 2 or rank 3.
    *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
    if (sparse_matrix_a.batch_size() != *batch_size) {
      return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
                                     sparse_matrix_a.batch_size(), " vs. ",
                                     *batch_size, ".");
    }
    const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64>();
    const int64_t a_inner_dim =
        a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
    const int64_t b_inner_dim =
        dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
    if (a_inner_dim != b_inner_dim) {
      return errors::InvalidArgument(
          "Inner product dimensions of A and B do not agree.  Shapes are: ",
          TensorShape(a_dense_shape), " vs. ",
          dense_tensor_b.shape().DebugString());
    }
    return Status::OK();
  }
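
  // Illustrative example (hypothetical shapes): for a rank-3 call with
  // a.dense_shape = [4, 100, 50] and b.shape = [4, 50, 20] and all
  // transpose/adjoint attributes false, ValidateInputs sets *rank = 3 and
  // *batch_size = 4, and the inner dimensions (50 vs. 50) agree. If b were
  // [4, 60, 20] instead, an InvalidArgument status would be returned.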

 public:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};
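
// Note on attribute normalization (illustrative): the constructor above folds
// adjoint into a (conjugate, transpose) pair. For example, adjoint_a == true
// with transpose_a == false yields conjugate_a_ == true and
// transpose_a_ == true, so the derived kernels only ever reason about the
// transpose_* and conjugate_* flags.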

// CPU Kernel to compute sparse-dense matrix multiplication.
//
// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
// available, the implementation parallelizes the computation across each row
// of the sparse matrix.
template <typename T>
class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulCPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<CPUDevice, T>(c) {}

  ~CSRMatMulCPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* sparse_matrix_a;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
    const Tensor& matrix_b = ctx->input(1);

    int rank;
    int64_t batch_size;
    OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
                                             &batch_size));

    const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64>();
    int64_t num_lhs_rows = dense_shape(rank - 2);
    int64_t num_lhs_cols = dense_shape(rank - 1);
    int64_t num_rhs_rows = matrix_b.dim_size(rank - 2);
    int64_t num_rhs_cols = matrix_b.dim_size(rank - 1);

    if (this->transpose_a_) {
      std::swap(num_lhs_rows, num_lhs_cols);
    }

    // Possibly transpose the dense Tensor b.
    const Tensor* rhs = &matrix_b;
    Tensor b_transposed;
    if (this->transpose_b_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
                                           &b_transposed));
      rhs = &b_transposed;
      std::swap(num_rhs_rows, num_rhs_cols);
    }

    // If we're transposing the output, then allocate a temporary buffer to
    // store the output. Otherwise allocate the output directly.
    Tensor* output = nullptr;
    Tensor* matmul_result = nullptr;
    Tensor output_transposed;
    OP_REQUIRES_OK(
        ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
                            this->transpose_output_, &output,
                            &output_transposed, &matmul_result));

    if (!this->transpose_a_) {
      SparseDenseMatMulWithoutTransposedLHS(
          ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
    } else {  // transpose_a_ == true
      SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
                                         num_lhs_cols, *sparse_matrix_a, *rhs,
                                         matmul_result);
    }

    // Transpose (and conjugate) the output if necessary.
    // Note that conjugate is only true if transpose is also true.
    if (this->transpose_output_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateAllocatedTensor(
                   ctx, output_transposed, this->conjugate_output_, output));
    } else if (this->conjugate_output_) {
      functor::maybe_conj_inplace<CPUDevice, T>::run(
          ctx->eigen_device<CPUDevice>(), output);
    }
  }

 private:
  // Allocates the output with the appropriate shape. Additionally, if
  // transpose_output is True, allocates a temporary buffer with the transposed
  // output. 'matmul_result' points to either output or output_transposed,
  // based on whether transpose_output is True.
  Status AllocateOutput(OpKernelContext* ctx, const int32_t rank,
                        const int64_t batch_size, const int64_t num_rows,
                        const int64_t num_cols, const bool transpose_output,
                        Tensor** output, Tensor* output_transposed,
                        Tensor** matmul_result) {
    TensorShape output_shape;
    if (rank == 3) output_shape.AddDim(batch_size);

    if (!transpose_output) {
      output_shape.AppendShape({num_rows, num_cols});
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = *output;
    } else {
      TensorShape output_transposed_shape = output_shape;
      output_transposed_shape.AppendShape({num_rows, num_cols});
      output_shape.AppendShape({num_cols, num_rows});
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                            output_transposed_shape,
                                            output_transposed));
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = output_transposed;
    }
    return Status::OK();
  }
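
  // Shape example (hypothetical values): with rank == 3, batch_size == 4,
  // num_rows == 100, num_cols == 20 and transpose_output == true, the output
  // is allocated with shape [4, 20, 100] and a temporary output_transposed
  // with shape [4, 100, 20]; *matmul_result then aliases output_transposed,
  // which Compute() transposes into the output at the end.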

  // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
  // contiguous segment of rows of the CSR Sparse Matrix.
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int batch_index,
      const int64_t row_begin, const int64_t num_shard_rows,
      std::vector<int32>* row_ptrs) {
    // Compute the row pointers of the sparse sub-matrix.
    row_ptrs->resize(num_shard_rows + 1);
    const int64_t row_offset =
        csr_matrix.row_pointers_vec(batch_index)(row_begin);
    for (int64_t row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
      row_ptrs->at(row_idx) =
          csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
          row_offset;
    }
    const int64_t num_cols =
        csr_matrix.dense_shape().vec<int64>()(csr_matrix.dims() - 1);
    return Eigen::Map<const SparseMatrix>(
        num_shard_rows /* num_rows */, num_cols /* num_cols */,
        row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
        csr_matrix.col_indices_vec(batch_index).data() + row_offset,
        csr_matrix.values_vec<T>(batch_index).data() + row_offset);
  }
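
  // Worked example (hypothetical values): if a batch component has
  // row_pointers [0, 2, 5, 7, 9] and the shard covers rows [1, 3), then
  // row_begin == 1, num_shard_rows == 2, row_offset == 2, and the sub-matrix
  // row pointers become [0, 3, 5]. The sub-matrix's column indices and values
  // start at offset row_offset == 2 into the batch's arrays, and it holds
  // row_ptrs[2] == 5 nonzeros.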

  // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
  // dense Tensor (RHS).
  void SparseDenseMatMulWithoutTransposedLHS(OpKernelContext* ctx,
                                             const int64_t batch_size,
                                             const int64_t num_lhs_rows,
                                             const CSRSparseMatrix& lhs,
                                             const Tensor& rhs,
                                             Tensor* output) {
    // Parallelize matrix multiplication across batch dimensions and across
    // rows in each batch.
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32_t num_threads = worker_threads.num_threads;
    const int64_t block_size =
        num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    worker_threads.workers->ParallelFor(
        batch_size * num_lhs_rows /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64_t batch_and_row_begin, int64_t batch_and_row_end) {
          HandleBatchAndRowRange(
              num_lhs_rows, batch_and_row_begin, batch_and_row_end,
              [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
                const int64_t num_shard_rows = row_end - row_begin;

                // Define an Eigen::SparseMatrix over the row range:
                // [row_begin, row_end) of the CSR SparseMatrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding rows of the rhs.
                ConstMatrixMap rhs_map(rhs.flat<T>().data() + batch_idx *
                                                                  num_rhs_rows *
                                                                  num_rhs_cols,
                                       num_rhs_rows, num_rhs_cols);

                // Write to the corresponding rows of the output matrix.
                MatrixMap output_map(
                    output->flat<T>().data() +
                        batch_idx * num_lhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);
                output_map.noalias() = sparse_matrix * rhs_map;
              });
        });
  }

  // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
  // to be transposed before the operation.
  void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
                                          const int64_t batch_size,
                                          const int64_t num_lhs_rows,
                                          const int64_t num_lhs_cols,
                                          const CSRSparseMatrix& lhs,
                                          const Tensor& rhs, Tensor* output) {
    auto device = ctx->eigen_device<CPUDevice>();
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32_t num_threads = worker_threads.num_threads;
    const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    // Usually, we want to avoid transposing the sparse matrix A since it may
    // be an expensive operation. Instead, we use the identity
    // (A^T B) = (B^T A)^T. We don't actually transpose B or the output
    // because it is more convenient to have them in column-major form.
    //
    // However, if A is hypersparse and B and C are huge, transposing A will be
    // cheaper. In the future, we should have a cost model estimating the cost
    // of transposing all matrices (A, B, C) to decide which variant to use.

    // Each thread writes to its own copy of the matrix product. These
    // `num_threads` copies are summed together to obtain the final result.
    Tensor matmul_result_buffer;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({num_threads + 1,
                                                        output->NumElements()}),
                                           &matmul_result_buffer));
    functor::SetZeroFunctor<CPUDevice, T> set_zero;
    set_zero(device, matmul_result_buffer.flat<T>());

    // Parallelize matrix multiplication across batch dimensions and across
    // columns of A^T in each batch. These correspond to rows of A.
    const int64_t block_size =
        num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    worker_threads.workers->ParallelForWithWorkerId(
        batch_size * num_lhs_cols /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64_t batch_and_row_begin, int64_t batch_and_row_end, int tid) {
          HandleBatchAndRowRange(
              num_lhs_cols, batch_and_row_begin, batch_and_row_end,
              [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
                const int64_t num_shard_rows = row_end - row_begin;

                // Define a new sparse sub-matrix from the row range
                // [row_begin, row_end) of the sparse matrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding `num_shard_rows` columns of B^T.
                // This is the same as taking the `num_shard_rows` rows of B.
                ConstMatrixMap b_dense_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);

                // Map to the corresponding rows of the output.
                MatrixMap output_map(
                    matmul_result_buffer.flat<T>().data() +
                        tid * batch_size * num_lhs_rows * num_rhs_cols +
                        batch_idx * num_lhs_rows * num_rhs_cols,
                    num_lhs_rows, num_rhs_cols);

                // Compute the product C^T = B^T * A; restricted to the row
                // range in the current shard.
                if (this->conjugate_a_) {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix.conjugate();
                } else {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix;
                }
              });
        });

    // Sum across each thread's matmul result.
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
        Eigen::array<Index, 1>({0}), Reducer());
  }
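
  // Illustrative dimensions (hypothetical): suppose a batch component of A is
  // stored as a 50 x 100 CSR matrix and B is 50 x 20, so the result
  // C = A^T * B is 100 x 20. Each shard covers a contiguous range of rows of
  // A (equivalently, columns of A^T); a shard of, say, 5 rows contributes
  // B_shard^T (20 x 5) times A_shard (5 x 100) into the transpose of the
  // owning thread's 100 x 20 accumulation buffer, and the per-thread buffers
  // are summed at the end.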

  // Given a range [batch_and_row_begin, batch_and_row_end) which is a
  // contiguous subset of [0, num_rows * batch_size), calls the function
  // fn(batch_idx, row_begin, row_end) for each batch index
  // and the row range [row_begin, row_end) contained in the batch.
  void HandleBatchAndRowRange(
      const int64_t num_rows, const int64_t batch_and_row_begin,
      const int64_t batch_and_row_end,
      const std::function<void(int64_t, int64_t, int64_t)>& fn) {
    // Obtain the batch indices overlapping with the current shard.
    const int64_t batch_begin = batch_and_row_begin / num_rows;
    const int64_t batch_end_inclusive = batch_and_row_end / num_rows;

    for (int64_t batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
         ++batch_idx) {
      // Find the contiguous set of rows which are contained in this shard as
      // well as the current batch. We intersect with interval [batch_idx *
      // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch.
      const int64_t current_batch_row_begin =
          std::max(batch_and_row_begin, batch_idx * num_rows);
      const int64_t current_batch_row_end =
          std::min(batch_and_row_end, (batch_idx + 1) * num_rows);

      const int64_t row_begin = current_batch_row_begin % num_rows;
      const int64_t num_shard_rows =
          current_batch_row_end - current_batch_row_begin;
      // Edge case for when current_batch_row_end is the first index of a new
      // batch.
      if (num_shard_rows == 0) continue;

      fn(batch_idx, row_begin, row_begin + num_shard_rows);
    }
  }
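
  // Worked example (hypothetical values): with num_rows == 10 and the range
  // [batch_and_row_begin, batch_and_row_end) == [8, 23), this calls
  //   fn(0, 8, 10)   // tail of batch 0
  //   fn(1, 0, 10)   // all of batch 1
  //   fn(2, 0, 3)    // head of batch 2
  // and a range ending exactly on a batch boundary, e.g. [8, 20), skips the
  // empty intersection with batch 2.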

  // Transposes (and optionally, conjugates) a given Tensor. Also allocates the
  // required memory for the output Tensor.
  Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
                                     bool conjugate, Tensor* output) {
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(input.dims() - 1,
                             input.dim_size(input.dims() - 2));
    transposed_shape.set_dim(input.dims() - 2,
                             input.dim_size(input.dims() - 1));
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
    return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
  }

  // Transposes (and optionally, conjugates) a given Tensor. The output should
  // be already allocated.
  Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
                                              const Tensor& input,
                                              bool conjugate, Tensor* output) {
    if (conjugate) {
      TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
          ctx->eigen_device<CPUDevice>(), input, output));
    } else {
      TF_RETURN_IF_ERROR(
          DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
    }
    return Status::OK();
  }
};

// GPU Kernel to compute sparse-dense matrix multiplication.
template <typename T>
class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulGPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<GPUDevice, T>(c) {}

  ~CSRMatMulGPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    int rank;
    int64_t batch_size;
    OP_REQUIRES_OK(ctx,
                   this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (rank == 2) ? 0 : 1;
    const int64_t a_outer_dim = a_dense_tensor_shape.dim_size(
        this->transpose_a_ ? row_dim + 1 : row_dim);
    const int64_t b_inner_dim =
        b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
    const int64_t b_outer_dim =
        b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
    const int64_t b_slice_size = b_inner_dim * b_outer_dim;

    TensorShape c_shape;
    if (rank == 3) c_shape.AddDim(batch_size);
    if (this->transpose_output_) {
      c_shape.AddDim(b_outer_dim);
      c_shape.AddDim(a_outer_dim);
    } else {
      c_shape.AddDim(a_outer_dim);
      c_shape.AddDim(b_outer_dim);
    }

    const int64_t c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64_t c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64_t c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    bool use_matrix_vector_multiply = (b_outer_dim == 1);
#if TENSORFLOW_USE_ROCM
    // ROCm hipsparse does not implement csrmv with transposed input a.
    use_matrix_vector_multiply =
        use_matrix_vector_multiply && !this->transpose_a_;
#endif
    if (use_matrix_vector_multiply) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
                                                 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = this->conjugate_a_;
      bool conjugate_output = this->conjugate_output_;
      if (this->conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
                                                            conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
        this->transpose_output_);

    Tensor c_mat_col_major_t;
    if (!this->transpose_output_) {
      // If transpose_output is false, we'll need to transpose the (col
      // major) output of the csrgemm call to get proper (row-major)
      // output.  Which means we need to keep a temporary buffer to
      // store the intermediate gemm output.
      TensorShape c_mat_col_major_shape;
      if (rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major i.e.,
    // transposed) output of the csrgemm call.  Otherwise we'll need
    // to transpose it to row major format.
    auto c_mat_col_major = (this->transpose_output_)
                               ? c_t->flat<T>()
                               : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!this->transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
      OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
                                    &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!this->transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             b_t_transposed_shape, &b_t_input));
      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
      if (this->conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
    TTypes<int64>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!this->transpose_output_) {
      // We need to return values in row major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (this->conjugate_output_) {
      functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
    }
  }
};

#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CSRMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU
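
// For reference, REGISTER_CPU(float) above expands (modulo the registration
// macro's internal unique naming) to roughly:
//   REGISTER_KERNEL_BUILDER(
//       Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       CSRMatMulCPUOp<float>);
// i.e., one CSRMatMulCPUOp instantiation is registered per supported dtype.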

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CSRMatMulGPUOp<T>);

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

namespace {

// GPUDataType<T>::type translates from a C++ type (e.g. float) to a
// GPUDataType_t (e.g. CUDA_R_32F).
template <typename T>
struct GPUDataType;

// GPUDataType templates are currently not instantiated in the ROCm flow, so
// the #elif TENSORFLOW_USE_ROCM blocks are left out for now. The hipblas
// library is not (yet) pulled in via rocm_configure.bzl, so types from the
// hipblas headers cannot be referenced here.
template <>
struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_16F;
#endif
};

template <>
struct GPUDataType<float> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_32F;
#endif
};

template <>
struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_32F;
#endif
};

template <>
struct GPUDataType<double> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_64F;
#endif
};

template <>
struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_64F;
#endif
};

}  // namespace

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm/SpMM to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm/SpMM assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row major output (into which we will store
      // csrgemm's col major output).  Otherwise it's a
      // temporary tensor that will store the column major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k)
      // if op(A) = A and at least max(1, m) otherwise. If op(B) != B, it must
      // be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;
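
      // Illustrative sizing (hypothetical values): for a 100 x 50 sparse A
      // (m = 100, k = 50) multiplied by a row-major 50 x 20 dense B (n = 20),
      // the calls below see ldb = n = 20 and ldc = m = 100, and C is written
      // in column-major form with m * n = 2000 entries.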

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
#if GOOGLE_CUDA
      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose).  This version is actually more
      // efficient.
#if GOOGLE_CUDA && CUDA_VERSION >= 10020

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
                              GPUDataType<T>::type, CUSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
                              CUSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          CUSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));

#elif TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 40200
      // Use SpMM.
      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
          &matB, n, k, ldb, const_cast<T*>(b.data()), GPUDataType<T>::type,
          HIPSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
          &matC, m, n, ldc, c.data(), GPUDataType<T>::type,
          HIPSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          HIPSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroySpMat(matA));

#else

#if GOOGLE_CUDA

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

#elif TENSORFLOW_USE_ROCM

      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA

      TF_RETURN_IF_ERROR(
          cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                            a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                            b.data(), ldb, &beta, c.data(), ldc));

#endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
    }

    return Status::OK();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
                                                   &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

#if GOOGLE_CUDA && CUDA_VERSION < 10020
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
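
      // Sizing note (illustrative): for a dense shape of, say, [m, n] =
      // [100, 50], x is read as a length-50 vector and y is written as a
      // length-100 vector when transA_ is non-transpose; with a transposed
      // (or adjoint) op the roles of the two lengths swap.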
#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#else
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#endif
    }

    return Status::OK();
  }

 private:
  Status status_;
  const gpusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow