/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

// TODO(anudhyan): These constants may be tuned based on the performance of
// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
// which work across hardware platforms for typical matrix sizes. It should be
// possible to observe at least 30-50% improvement as we increase the number
// of threads by 1. If not, then it may be worth increasing kMaxShards and
// kNumShardsPerThread. However, once we have too many shards, latency may be
// dominated by per-shard overhead.
//
// Maximum number of shards into which to divide the computation for each CSR
// Sparse Matrix instance.
static constexpr int32 kMaxShards = 20;
// Number of shards allocated to each thread.
static constexpr int32 kNumShardsPerThread = 3;
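// For example, with 8 intra-op threads the rows of each batch are divided
// into roughly max(kMaxShards, 8 * kNumShardsPerThread) = 24 fixed-size
// blocks (see the block_size computations below).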

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Abstract OpKernel to compute sparse-dense matrix multiplication.
//
// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
// multiplication.
//
// The boolean attributes `transpose_a` and `adjoint_a` will transpose or
// adjoint `a` before multiplication, respectively. At most one of these
// attributes may be set to True. Corresponding attributes will transpose or
// adjoint `b` or the output (after multiplication).
//
// The rank of both `a` and `b` must be equal and their shapes must be
// compatible for matrix multiplication. Otherwise, InvalidArgument runtime
// errors will be thrown. Only rank 2 or rank 3 inputs are supported.
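// For example (rank 3, no transposes): if `a` has dense shape [B, M, K] and
// `b` has shape [B, K, N], the output `c` has shape [B, M, N], or [B, N, M]
// when `transpose_output` is True.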
//
template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ |= adjoint_a;
    transpose_b_ |= adjoint_b;
  }

  ~CSRMatMulOp() override {}

  Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
                        const Tensor& dense_tensor_b, int* rank,
                        int64* batch_size) {
    if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
      return errors::InvalidArgument(
          "Input types don't match.  a.dtype == ",
          DataTypeString(sparse_matrix_a.dtype()),
          " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
    }
    *rank = sparse_matrix_a.dims();
    // TODO(ebrevdo): Add support for broadcasting matmul.
    if (*rank != dense_tensor_b.dims()) {
      return errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                     *rank, " vs. ", dense_tensor_b.dims(),
                                     ".");
    }
    // A valid CSR SparseMatrix has rank 2 or rank 3.
    *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
    if (sparse_matrix_a.batch_size() != *batch_size) {
      return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
                                     sparse_matrix_a.batch_size(), " vs. ",
                                     *batch_size, ".");
    }
    const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64>();
    const int64 a_inner_dim =
        a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
    const int64 b_inner_dim =
        dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
    if (a_inner_dim != b_inner_dim) {
      return errors::InvalidArgument(
          "Inner product dimensions of A and B do not agree.  Shapes are: ",
          TensorShape(a_dense_shape), " vs. ",
          dense_tensor_b.shape().DebugString());
    }
    return Status::OK();
  }

 public:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};

// CPU Kernel to compute sparse-dense matrix multiplication.
//
// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
// available, the implementation parallelizes the computation across each row
// of the sparse matrix.
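//
// When transpose_a is false, each shard multiplies a contiguous block of rows
// of `a` against `b` and writes the matching rows of the output directly. When
// transpose_a is true, the identity (A^T * B) = (B^T * A)^T is used instead of
// materializing A^T (see SparseDenseMatMulWithTransposedLHS below).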
template <typename T>
class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulCPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<CPUDevice, T>(c) {}

  ~CSRMatMulCPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* sparse_matrix_a;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
    const Tensor& matrix_b = ctx->input(1);

    int rank;
    int64 batch_size;
    OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
                                             &batch_size));

    const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64>();
    int64 num_lhs_rows = dense_shape(rank - 2);
    int64 num_lhs_cols = dense_shape(rank - 1);
    int64 num_rhs_rows = matrix_b.dim_size(rank - 2);
    int64 num_rhs_cols = matrix_b.dim_size(rank - 1);

    if (this->transpose_a_) {
      std::swap(num_lhs_rows, num_lhs_cols);
    }

    // Possibly transpose the dense Tensor b.
    const Tensor* rhs = &matrix_b;
    Tensor b_transposed;
    if (this->transpose_b_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
                                           &b_transposed));
      rhs = &b_transposed;
      std::swap(num_rhs_rows, num_rhs_cols);
    }

    // If we're transposing the output, then allocate a temporary buffer to
    // store the output. Otherwise allocate the output directly.
    Tensor* output = nullptr;
    Tensor* matmul_result = nullptr;
    Tensor output_transposed;
    OP_REQUIRES_OK(
        ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
                            this->transpose_output_, &output,
                            &output_transposed, &matmul_result));

    if (!this->transpose_a_) {
      SparseDenseMatMulWithoutTransposedLHS(
          ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
    } else {  // transpose_a_ == true
      SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
                                         num_lhs_cols, *sparse_matrix_a, *rhs,
                                         matmul_result);
    }

    // Transpose and/or conjugate the output if requested.
    if (this->transpose_output_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateAllocatedTensor(
                   ctx, output_transposed, this->conjugate_output_, output));
    } else if (this->conjugate_output_) {
      functor::maybe_conj_inplace<CPUDevice, T>::run(
          ctx->eigen_device<CPUDevice>(), output);
    }
  }

 private:
  // Allocates the output with the appropriate shape. Additionally, if
  // transpose_output is True, allocates a temporary buffer with the transposed
  // output. 'matmul_result' points to either output or output_transposed,
  // based on whether transpose_output is True.
  Status AllocateOutput(OpKernelContext* ctx, const int32 rank,
                        const int64 batch_size, const int64 num_rows,
                        const int64 num_cols, const bool transpose_output,
                        Tensor** output, Tensor* output_transposed,
                        Tensor** matmul_result) {
    TensorShape output_shape;
    if (rank == 3) output_shape.AddDim(batch_size);

    if (!transpose_output) {
      output_shape.AppendShape({num_rows, num_cols});
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = *output;
    } else {
      TensorShape output_transposed_shape = output_shape;
      output_transposed_shape.AppendShape({num_rows, num_cols});
      output_shape.AppendShape({num_cols, num_rows});
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                            output_transposed_shape,
                                            output_transposed));
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = output_transposed;
    }
    return Status::OK();
  }

  // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
  // contiguous segment of rows of the CSR Sparse Matrix.
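  // Note: `row_ptrs` provides the storage backing the returned Eigen::Map, so
  // it must remain alive for as long as the returned reference is used.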
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int batch_index,
      const int64 row_begin, const int64 num_shard_rows,
      std::vector<int32>* row_ptrs) {
    // Compute the row pointers of the sparse sub-matrix.
    row_ptrs->resize(num_shard_rows + 1);
    const int64 row_offset =
        csr_matrix.row_pointers_vec(batch_index)(row_begin);
    for (int64 row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
      row_ptrs->at(row_idx) =
          csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
          row_offset;
    }
    const int64 num_cols =
        csr_matrix.dense_shape().vec<int64>()(csr_matrix.dims() - 1);
    return Eigen::Map<const SparseMatrix>(
        num_shard_rows /* num_rows */, num_cols /* num_cols */,
        row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
        csr_matrix.col_indices_vec(batch_index).data() + row_offset,
        csr_matrix.values_vec<T>(batch_index).data() + row_offset);
  }

  // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
  // dense Tensor (RHS).
  void SparseDenseMatMulWithoutTransposedLHS(
      OpKernelContext* ctx, const int64 batch_size, const int64 num_lhs_rows,
      const CSRSparseMatrix& lhs, const Tensor& rhs, Tensor* output) {
    // Parallelize matrix multiplication across batch dimensions and across
    // rows in each batch.
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32 num_threads = worker_threads.num_threads;
    const int64 block_size =
        num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    worker_threads.workers->ParallelFor(
        batch_size * num_lhs_rows /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64 batch_and_row_begin, int64 batch_and_row_end) {
          HandleBatchAndRowRange(
              num_lhs_rows, batch_and_row_begin, batch_and_row_end,
              [&](int64 batch_idx, int64 row_begin, int64 row_end) {
                const int64 num_shard_rows = row_end - row_begin;

                // Define an Eigen::SparseMatrix over the row range:
                // [row_begin, row_end) of the CSR SparseMatrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding rows of the rhs.
                ConstMatrixMap rhs_map(rhs.flat<T>().data() + batch_idx *
                                                                  num_rhs_rows *
                                                                  num_rhs_cols,
                                       num_rhs_rows, num_rhs_cols);

                // Write to the corresponding rows of the output matrix.
                MatrixMap output_map(
                    output->flat<T>().data() +
                        batch_idx * num_lhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);
                output_map.noalias() = sparse_matrix * rhs_map;
              });
        });
  }

  // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
  // to be transposed before the operation.
  void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
                                          const int64 batch_size,
                                          const int64 num_lhs_rows,
                                          const int64 num_lhs_cols,
                                          const CSRSparseMatrix& lhs,
                                          const Tensor& rhs, Tensor* output) {
    auto device = ctx->eigen_device<CPUDevice>();
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32 num_threads = worker_threads.num_threads;
    const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    // Usually, we want to avoid transposing the sparse matrix A since it may be
    // an expensive operation. Instead, we use the identity (A^T B) = (B^T A)^T.
    // We don't actually transpose B or the output because it is more convenient
    // to have them in column major form.
    //
    // However, if A is hypersparse and B and C are huge, transposing A will be
    // cheaper. In the future, we should have a cost model estimating the cost
    // of transposing all matrices (A, B, C) to decide which variant to use.

    // Each thread writes to its own copy of the matrix product. These
    // `num_threads` copies are summed together to obtain the final result.
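    // (Unlike the non-transposed case, the shards here cannot write disjoint
    // regions of the output: every shard of rows of A may contribute to every
    // row of C, so per-thread accumulators are summed at the end instead.)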
    Tensor matmul_result_buffer;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({num_threads + 1,
                                                        output->NumElements()}),
                                           &matmul_result_buffer));
    functor::SetZeroFunctor<CPUDevice, T> set_zero;
    set_zero(device, matmul_result_buffer.flat<T>());

    // Parallelize matrix multiplication across batch dimensions and across
    // columns of A^T in each batch. These correspond to rows of A.
    const int64 block_size =
        num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    worker_threads.workers->ParallelForWithWorkerId(
        batch_size * num_lhs_cols /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64 batch_and_row_begin, int64 batch_and_row_end, int tid) {
          HandleBatchAndRowRange(
              num_lhs_cols, batch_and_row_begin, batch_and_row_end,
              [&](int64 batch_idx, int64 row_begin, int64 row_end) {
                const int64 num_shard_rows = row_end - row_begin;

                // Define a new sparse sub-matrix from the row range
                // [row_begin, row_end) of the sparse matrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding `num_shard_rows` columns of B^T.
                // This is the same as taking the `num_shard_rows` rows of B.
                ConstMatrixMap b_dense_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);

                // Map the output matrix for this batch within this thread's
                // accumulation buffer.
                MatrixMap output_map(
                    matmul_result_buffer.flat<T>().data() +
                        tid * batch_size * num_lhs_rows * num_rhs_cols +
                        batch_idx * num_lhs_rows * num_rhs_cols,
                    num_lhs_rows, num_rhs_cols);

                // Compute the product C^T = B^T * A, restricted to the rows
                // of A (and B) in the current shard.
                if (this->conjugate_a_) {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix.conjugate();
                } else {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix;
                }
              });
        });

    // Sum across each thread's matmul result.
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
        Eigen::array<Index, 1>({0}), Reducer());
  }

  // Given a range [batch_and_row_begin, batch_and_row_end) which is a
  // contiguous subset of [0, num_rows * batch_size), calls the function
  // fn(batch_idx, row_begin, row_end) for each batch index
  // and the row range [row_begin, row_end) contained in the batch.
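  // For example, with num_rows = 10 and the range [8, 23), fn is called with
  // (batch_idx, row_begin, row_end) = (0, 8, 10), (1, 0, 10), and (2, 0, 3).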
  void HandleBatchAndRowRange(
      const int64 num_rows, const int64 batch_and_row_begin,
      const int64 batch_and_row_end,
      const std::function<void(int64, int64, int64)>& fn) {
    // Obtain the batch indices overlapping with the current shard.
    const int64 batch_begin = batch_and_row_begin / num_rows;
    const int64 batch_end_inclusive = batch_and_row_end / num_rows;

    for (int64 batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
         ++batch_idx) {
      // Find the contiguous set of rows which are contained in this shard as
      // well as the current batch. We intersect with interval [batch_idx *
      // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch.
      const int64 current_batch_row_begin =
          std::max(batch_and_row_begin, batch_idx * num_rows);
      const int64 current_batch_row_end =
          std::min(batch_and_row_end, (batch_idx + 1) * num_rows);

      const int64 row_begin = current_batch_row_begin % num_rows;
      const int64 num_shard_rows =
          current_batch_row_end - current_batch_row_begin;
      // Edge case for when current_batch_row_end is the first index of a new
      // batch, leaving no rows of this batch in the shard.
      if (num_shard_rows == 0) continue;

      fn(batch_idx, row_begin, row_begin + num_shard_rows);
    }
  }

  // Transposes (and optionally, conjugates) a given Tensor. Also allocates the
  // required memory for the output Tensor.
  Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
                                     bool conjugate, Tensor* output) {
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(input.dims() - 1,
                             input.dim_size(input.dims() - 2));
    transposed_shape.set_dim(input.dims() - 2,
                             input.dim_size(input.dims() - 1));
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
    return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
  }

  // Transposes (and optionally, conjugates) a given Tensor. The output should
  // be already allocated.
  Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
                                              const Tensor& input,
                                              bool conjugate, Tensor* output) {
    if (conjugate) {
      TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
          ctx->eigen_device<CPUDevice>(), input, output));
    } else {
      TF_RETURN_IF_ERROR(
          DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
    }
    return Status::OK();
  }
};

// GPU Kernel to compute sparse-dense matrix multiplication.
template <typename T>
class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulGPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<GPUDevice, T>(c) {}

  ~CSRMatMulGPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    int rank;
    int64 batch_size;
    OP_REQUIRES_OK(ctx,
                   this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (rank == 2) ? 0 : 1;
    const int64 a_outer_dim = a_dense_tensor_shape.dim_size(
        this->transpose_a_ ? row_dim + 1 : row_dim);
    const int64 b_inner_dim =
        b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
    const int64 b_outer_dim =
        b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
    const int64 b_slice_size = b_inner_dim * b_outer_dim;

    TensorShape c_shape;
    if (rank == 3) c_shape.AddDim(batch_size);
    if (this->transpose_output_) {
      c_shape.AddDim(b_outer_dim);
      c_shape.AddDim(a_outer_dim);
    } else {
      c_shape.AddDim(a_outer_dim);
      c_shape.AddDim(b_outer_dim);
    }

    const int64 c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64 c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64 c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    bool use_matrix_vector_multiply = (b_outer_dim == 1);
#if TENSORFLOW_USE_ROCM
    // ROCm hipsparse does not implement csrmv with transposed input a.
    use_matrix_vector_multiply =
        use_matrix_vector_multiply && !this->transpose_a_;
#endif
    if (use_matrix_vector_multiply) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
                                                 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = this->conjugate_a_;
      bool conjugate_output = this->conjugate_output_;
      if (this->conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
                                                            conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
        this->transpose_output_);

    Tensor c_mat_col_major_t;
    if (!this->transpose_output_) {
      // If transpose_output is false, we'll need to transpose the (col-major)
      // output of the csrgemm call to get proper (row-major) output, which
      // means we need to keep a temporary buffer to store the intermediate
      // gemm output.
      TensorShape c_mat_col_major_shape;
      if (rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major i.e.,
    // transposed) output of the csrgemm call.  Otherwise we'll need
    // to transpose it to row major format.
    auto c_mat_col_major = (this->transpose_output_)
                               ? c_t->flat<T>()
                               : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!this->transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
      OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
                                    &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!this->transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             b_t_transposed_shape, &b_t_input));
      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
      if (this->conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
    TTypes<int64>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!this->transpose_output_) {
      // We need to return values in row major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (this->conjugate_output_) {
      functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
    }
  }
};

#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CSRMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CSRMatMulGPUOp<T>);

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

namespace {

// GPUDataType<T>::type translates from a C++ type (e.g. float) to a
// GPUDataType_t (e.g. CUDA_R_32F).
template <typename T>
struct GPUDataType;

// GPUDataType templates are currently not instantiated in the ROCm flow, so
// we leave out the #elif TENSORFLOW_USE_ROCM blocks for now. The hipblas
// library is not (yet) being pulled in via rocm_configure.bzl, so we cannot
// reference types from hipblas headers here.
template <>
struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_16F;
#endif
};

template <>
struct GPUDataType<float> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_32F;
#endif
};

template <>
struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_32F;
#endif
};

template <>
struct GPUDataType<double> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_64F;
#endif
};

template <>
struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_64F;
#endif
};

}  // namespace

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm/SpMM to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm/Spmm assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row major output (into which we will store
      // csrgemm's col major output).  Otherwise it's a
      // temporary tensor that will store the column major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k)
      // if op(A) = A and at least max (1, m) otherwise. If op(B) != B, it must
      // be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
#if GOOGLE_CUDA
      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose).  This version is actually more
      // efficient.
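      // (Passing the row-major (k x n) buffer of `b` as a col-major matrix
      // with leading dimension ldb == n and transB == transpose makes op(B)
      // equal to the original B without an explicit copy.)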
#if GOOGLE_CUDA && CUDA_VERSION >= 10020

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      // NOTE: the following APIs are not available in ROCM
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
                              GPUDataType<T>::type, CUSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
                              CUSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          CUSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));

#else

#if GOOGLE_CUDA

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

#elif TENSORFLOW_USE_ROCM

      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA

      TF_RETURN_IF_ERROR(
          cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                            a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                            b.data(), ldb, &beta, c.data(), ldc));

#endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
    }

    return Status::OK();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
                                                   &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

#if GOOGLE_CUDA && CUDA_VERSION < 10020
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#else
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#endif
    }

    return Status::OK();
  }

 private:
  Status status_;
  const gpusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow