/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include <memory>
#include <numeric>

#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Swaps the dim sizes at two given dimensions of a TensorShape.
// Callers are responsible for making sure the given dimensions are within the
// valid dimension range of the TensorShape.
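// For example, swapping dimensions 1 and 2 of a [B, M, N] shape yields
// [B, N, M].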
void SwapDimSizes(const int dim_a, const int dim_b, TensorShape* shape) {
  const int64 size_a = shape->dim_size(dim_a);
  const int64 size_b = shape->dim_size(dim_b);
  shape->set_dim(dim_a, size_b);
  shape->set_dim(dim_b, size_a);
}

}  // namespace

// Op to compute the matrix multiplication of two CSR Sparse Matrices.
//
// Implements a CPU kernel to perform matrix multiplication using Eigen
// SparseMatrix and its Sparse-Sparse matmul. Supports transposing and
// adjointing on the fly for both the inputs without actually constructing the
// transpose or adjoint.
//
// This implementation does not support broadcasting. Hence both the input
// CSRSparseMatrices must have the same rank. (Either rank 2 or rank 3).
//
// The output CSRSparseMatrix may contain numeric (non-structural) zeros.
// TODO(anudhyan): Consider exposing whether to prune zeros as an attribute in
// the op's interface.
//
// If multiple threads are available, we parallelize across multiple batches
// using Eigen ThreadPool. Within a single batch, we run in single threaded mode
// because Eigen's Sparse-Sparse matmul doesn't support multithreading.
//
// TODO(b/126472741): Due to the multiple batches of a 3D CSRSparseMatrix being
// laid out in contiguous memory, this implementation allocates memory to store
// a temporary copy of the matrix product. Consequently, it uses roughly twice
// the amount of memory that it needs to. This may cause a memory blowup for
// sparse matrices with a high number of non-zero elements.
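//
// Example (illustrative only): with rank-3 inputs
//   A: CSRSparseMatrix with dense_shape [2, 3, 4]
//   B: CSRSparseMatrix with dense_shape [2, 4, 5]
// and all transpose/adjoint attrs false, the output is a CSRSparseMatrix with
// dense_shape [2, 3, 5], computed as two independent 3x4 * 4x5 sparse
// products.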
template <typename T>
class CSRSparseMatMulCPUOp : public OpKernel {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;

 public:
  explicit CSRSparseMatMulCPUOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a_));
    OP_REQUIRES(c, !(adjoint_a_ && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b_));
    OP_REQUIRES(c, !(adjoint_b_ && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
  }

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* input_matrix_a;
    const CSRSparseMatrix* input_matrix_b;
    // TODO(anudhyan): Factor out common validation logic in CPU and GPU Ops
    // into a common base class.
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix_a));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &input_matrix_b));
    OP_REQUIRES(ctx, input_matrix_a->dtype() == DataTypeToEnum<T>::value,
                errors::InvalidArgument(
                    "dtype of a is not equal to 'type': ",
                    DataTypeString(input_matrix_a->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(ctx, input_matrix_b->dtype() == DataTypeToEnum<T>::value,
                errors::InvalidArgument(
                    "dtype of b is not equal to 'type': ",
                    DataTypeString(input_matrix_b->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(ctx,
                input_matrix_a->batch_size() == input_matrix_b->batch_size(),
                errors::InvalidArgument(
                    "Batch sizes of A and B do not agree.  Batch sizes are: ",
                    input_matrix_a->batch_size(), " vs. ",
                    input_matrix_b->batch_size()));

    // Validate input_matrix_a's and input_matrix_b's shapes
    TensorShape a_shape;
    TensorShape b_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(
                       input_matrix_a->dense_shape().vec<int64>(), &a_shape));
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(
                       input_matrix_b->dense_shape().vec<int64>(), &b_shape));

    const int rank = a_shape.dims();
    const int row_dim = (rank == 2) ? 0 : 1;
    if (transpose_a_ || adjoint_a_)
      SwapDimSizes(row_dim, row_dim + 1, &a_shape);
    if (transpose_b_ || adjoint_b_)
      SwapDimSizes(row_dim, row_dim + 1, &b_shape);
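    // E.g. if A has dense shape [B, M, K] and transpose_a (or adjoint_a) is
    // set, a_shape becomes [B, K, M] here, so the inner-dimension check below
    // applies unchanged.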

    OP_REQUIRES(
        ctx, a_shape.dim_size(row_dim + 1) == b_shape.dim_size(row_dim),
        errors::InvalidArgument(
            "Inner product dimensions of A and B do not agree.  Shapes are: ",
            a_shape.DebugString(), " vs. ", b_shape.DebugString()));

    // Infer the output shape of the matrix product.
    // TODO(ebrevdo): MatMul support for broadcasting at least in the
    // batch dimension.
    const int batch_size = input_matrix_a->batch_size();
    Tensor output_shape(cpu_allocator(), DT_INT64, TensorShape({rank}));
    auto output_shape_vec = output_shape.vec<int64>();
    if (rank == 3) output_shape_vec(0) = batch_size;
    output_shape_vec(row_dim) = a_shape.dim_size(row_dim);
    output_shape_vec(row_dim + 1) = b_shape.dim_size(row_dim + 1);

    // Set batch pointers.
    Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
    auto batch_ptr_vec = batch_ptr.vec<int32>();
    batch_ptr_vec(0) = 0;

    // Store intermediate matrix products for each batch.
    // TODO(b/126472741): For a single batch, consider reusing the
    // SparseMatrices' buffers to construct the CSRSparseMatrix to prevent 2x
    // memory usage.
    std::vector<SparseMatrix> output_matrices(batch_size);

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    // Estimate the cost per batch as num_output_rows times the product of the
    // average number of nonzeros per row of A and B.
    const int64 num_output_rows = output_shape_vec(row_dim);
    const double avg_nnz_per_row_a =
        input_matrix_a->total_nnz() /
        static_cast<double>(a_shape.dim_size(row_dim) * batch_size);
    const double avg_nnz_per_row_b =
        input_matrix_b->total_nnz() /
        static_cast<double>(b_shape.dim_size(row_dim) * batch_size);
    const int64 matmul_cost_per_batch =
        num_output_rows * (avg_nnz_per_row_a * avg_nnz_per_row_b);
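    // For example, with 1024 output rows and an average of 4 nonzeros per row
    // in A and 8 in B, the estimated cost is 1024 * 4 * 8 = 32768 per batch.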

    // Parallelize matrix multiplication across batches.
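    // Shard splits the [0, batch_size) range across the available worker
    // threads, using matmul_cost_per_batch as the per-unit cost estimate to
    // decide how finely to subdivide the work.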
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          matmul_cost_per_batch, [&](int64 batch_begin, int64 batch_end) {
            for (int64 batch_idx = batch_begin; batch_idx < batch_end;
                 ++batch_idx) {
              // For each batch, map the CSRSparseMatrix as Eigen SparseMatrix
              // without copying the underlying data.
              auto a_ref = GetSparseMatrixRef(*input_matrix_a, rank, batch_idx,
                                              transpose_a_, adjoint_a_);
              auto b_ref = GetSparseMatrixRef(*input_matrix_b, rank, batch_idx,
                                              transpose_b_, adjoint_b_);

              // Matrix multiply while *not* pruning numerical zeros on the fly.
              // Allocates output SparseMatrix and moves it to our list of
              // output_matrices.
              output_matrices[batch_idx] = a_ref * b_ref;

              // For now, batch_ptr contains the number of nonzeros in each
              // batch.
              batch_ptr_vec(batch_idx + 1) =
                  output_matrices[batch_idx].nonZeros();
            }
          });

    // Compute the cumulative sum to obtain the batch pointers.
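    // For example, per-batch nonzero counts {3, 0, 5} are stored as
    // {0, 3, 0, 5} (including the leading 0); the cumulative sum turns this
    // into the batch pointers {0, 3, 3, 8}, so total_nnz == 8.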
    std::partial_sum(batch_ptr_vec.data(),
                     batch_ptr_vec.data() + batch_size + 1,
                     batch_ptr_vec.data());
    const int64 total_nnz = batch_ptr_vec(batch_size);

    // Allocate output tensors.
    Tensor output_row_ptr(cpu_allocator(), DT_INT32,
                          TensorShape({(num_output_rows + 1) * batch_size}));
    Tensor output_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
    Tensor output_values(cpu_allocator(), DataTypeToEnum<T>::value,
                         TensorShape({total_nnz}));
    auto output_row_ptr_ptr = output_row_ptr.flat<int32>().data();
    auto output_col_ind_ptr = output_col_ind.flat<int32>().data();
    auto output_values_ptr = output_values.flat<T>().data();

    // Copy the output matrices from each batch into the CSRSparseMatrix
    // tensors.
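    // Row pointers for batch i occupy the slice
    // [i * (num_output_rows + 1), (i + 1) * (num_output_rows + 1)), while its
    // column indices and values start at offset batch_ptr_vec(i).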
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          (3 * total_nnz) / batch_size /* cost per unit */,
          [&](int64 batch_begin, int64 batch_end) {
            for (int64 batch_idx = batch_begin; batch_idx < batch_end;
                 ++batch_idx) {
              const SparseMatrix& output_matrix = output_matrices[batch_idx];
              const int64 nnz = output_matrix.nonZeros();
              std::copy(output_matrix.outerIndexPtr(),
                        output_matrix.outerIndexPtr() + num_output_rows + 1,
                        output_row_ptr_ptr + batch_idx * (num_output_rows + 1));
              std::copy(output_matrix.innerIndexPtr(),
                        output_matrix.innerIndexPtr() + nnz,
                        output_col_ind_ptr + batch_ptr_vec(batch_idx));
              std::copy(output_matrix.valuePtr(),
                        output_matrix.valuePtr() + nnz,
                        output_values_ptr + batch_ptr_vec(batch_idx));
            }
          });

    // Create the CSRSparseMatrix object from its component Tensors and prepare
    // the Variant output Tensor.
    CSRSparseMatrix output_csr_matrix;
    OP_REQUIRES_OK(ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
                            DataTypeToEnum<T>::value, output_shape, batch_ptr,
                            output_row_ptr, output_col_ind, output_values,
                            &output_csr_matrix));
    Tensor* output_csr_matrix_tensor;
    AllocatorAttributes cpu_alloc;
    cpu_alloc.set_on_host(true);
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
                                  cpu_alloc));
    output_csr_matrix_tensor->scalar<Variant>()() =
        std::move(output_csr_matrix);
  }

 private:
  // Returns an Eigen::Ref expression of a SparseMatrix that points to the
  // underlying memory of the given CSRSparseMatrix.
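  // For instance, a 2x3 matrix [[1, 0, 2], [0, 0, 3]] stored in CSR form has
  // row_pointers {0, 2, 3}, col_indices {0, 2, 2} and values {1, 2, 3}; the
  // Eigen::Map below aliases exactly those three buffers without copying.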
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int rank, const int batch_index,
      const bool transpose, const bool adjoint) {
    const auto dense_shape = csr_matrix.dense_shape().vec<int64>();
    const int64 num_rows = dense_shape(rank == 2 ? 0 : 1);
    const int64 num_cols = dense_shape(rank == 2 ? 1 : 2);

    Eigen::Map<const SparseMatrix> sparse_matrix(
        num_rows, num_cols, csr_matrix.nnz(batch_index),
        csr_matrix.row_pointers_vec(batch_index).data(),
        csr_matrix.col_indices_vec(batch_index).data(),
        csr_matrix.values_vec<T>(batch_index).data());

    // The transpose/adjoint expressions are not actually evaluated until
    // necessary. Hence we don't create copies or modify the input matrix
    // in place.
    if (transpose) return sparse_matrix.transpose();
    if (adjoint) return sparse_matrix.adjoint();
    return sparse_matrix;
  }

  bool transpose_a_;
  bool transpose_b_;
  bool adjoint_a_;
  bool adjoint_b_;
};

template <typename Device, typename T>
class CSRSparseMatMulGPUOp : public OpKernel {
 public:
  explicit CSRSparseMatMulGPUOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
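    // adjoint_a == true is handled as transpose_a_ == true plus
    // conjugate_a_ == true; the conjugation is applied by the explicit
    // transpose pass in Compute() below rather than by the GEMM kernel itself.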
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ = transpose_a_ || adjoint_a;
    transpose_b_ = transpose_b_ || adjoint_b;
  }

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    const CSRSparseMatrix* b_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &b_matrix));
    OP_REQUIRES(
        ctx, a_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of a is not equal to 'type': ",
                                DataTypeString(a_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(
        ctx, b_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of b is not equal to 'type': ",
                                DataTypeString(b_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // TODO(ebrevdo): MatMul support for broadcasting at least in the
    // batch dimension.
    auto a_dense_shape = a_matrix->dense_shape().vec<int64>();
    auto b_dense_shape = b_matrix->dense_shape().vec<int64>();

    TensorShape a_tensor_shape;
    TensorShape b_tensor_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(a_dense_shape, &a_tensor_shape));
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(b_dense_shape, &b_tensor_shape));

    const int rank = a_tensor_shape.dims();
    const int row_dim = (rank == 2) ? 0 : 1;

    const int64 a_inner_dim =
        a_tensor_shape.dim_size(transpose_a_ ? row_dim : row_dim + 1);
    const int64 b_inner_dim =
        b_tensor_shape.dim_size(transpose_b_ ? row_dim + 1 : row_dim);

    const int batch_size = a_matrix->batch_size();

    OP_REQUIRES(
        ctx, a_inner_dim == b_inner_dim,
        errors::InvalidArgument(
            "Inner product dimensions of A and B do not agree.  Shapes are: ",
            a_tensor_shape.DebugString(), " vs. ",
            b_tensor_shape.DebugString()));

    Tensor c_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
    auto c_dense_shape = c_dense_shape_t.vec<int64>();

    if (rank == 3) c_dense_shape(0) = batch_size;
    c_dense_shape(row_dim) =
        a_tensor_shape.dim_size(transpose_a_ ? row_dim + 1 : row_dim);
    c_dense_shape(row_dim + 1) =
        b_tensor_shape.dim_size(transpose_b_ ? row_dim : row_dim + 1);

    const int64 rows = c_dense_shape((rank == 2) ? 0 : 1);

    CSRSparseMatrix c;
    Tensor c_row_ptrs;

    // TODO(ebrevdo): Re-enable transposing within the GEMM kernel when cuSparse
    // stops spitting out CUSPARSE_STATUS_INTERNAL_ERROR values for transposes.
    functor::CSRSparseSparseMatrixMatMul<Device, T> csr_gemm(
        ctx, /*transpose_a=*/false, /*adjoint_a=*/false, /*transpose_b=*/false);
    OP_REQUIRES_OK(ctx, csr_gemm.Initialize());

    Tensor c_batch_ptr_t(cpu_allocator(), DT_INT32,
                         TensorShape({batch_size + 1}));
    auto c_batch_ptr = c_batch_ptr_t.vec<int32>();
    c_batch_ptr(0) = 0;

    Tensor c_row_ptr_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DT_INT32, TensorShape({batch_size * (rows + 1)}),
                            &c_row_ptr_t));
    auto c_row_ptr = c_row_ptr_t.vec<int32>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<Device, T> transpose;
      OP_REQUIRES_OK(
          ctx, transpose(ctx, conjugate_a_, *a_matrix, &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }
    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    const CSRSparseMatrix* b_input_matrix;
    // If we need to transpose b, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix b_matrix_transposed;
    if (!transpose_b_) {
      b_input_matrix = b_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<Device, T> transpose;
      OP_REQUIRES_OK(
          ctx, transpose(ctx, conjugate_b_, *b_matrix, &b_matrix_transposed));
      b_input_matrix = &b_matrix_transposed;
    }
    auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();

#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
    size_t maxWorkspaceSize = 0;
    for (int i = 0; i < batch_size; ++i) {
      // Calculate maximum workspace size over batch.
      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
                                  a_input_matrix->col_indices_vec(i),
                                  a_input_matrix->values_vec<T>(i),
                                  a_input_dense_shape};
      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
                                  b_input_matrix->col_indices_vec(i),
                                  b_input_matrix->values_vec<T>(i),
                                  b_input_dense_shape};
      size_t thisWorkspaceSize;
      OP_REQUIRES_OK(
          ctx, csr_gemm.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
      if (thisWorkspaceSize > maxWorkspaceSize) {
        maxWorkspaceSize = thisWorkspaceSize;
      }
    }

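    // The csrgemm2 scratch buffer is shared across all batch entries, so a
    // single workspace of the maximum required size is allocated once here.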
    Tensor temp;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}),
                 &temp));
    void* workspace = temp.flat<int8>().data();
#else
    void* workspace = nullptr;
#endif

    for (int i = 0; i < batch_size; ++i) {
      // Calculate output sizes for all minibatch entries.
      // Store in c_batch_ptr and update c_row_ptrs.
      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
                                  a_input_matrix->col_indices_vec(i),
                                  a_input_matrix->values_vec<T>(i),
                                  a_input_dense_shape};
      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
                                  b_input_matrix->col_indices_vec(i),
                                  b_input_matrix->values_vec<T>(i),
                                  b_input_dense_shape};

      TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
                                              rows + 1);

      int c_nnz_i;
      OP_REQUIRES_OK(ctx,
                     csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
                                                 &c_nnz_i, workspace));
      c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
    }
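    // Unlike the CPU kernel above, which needs a separate partial sum,
    // c_batch_ptr is accumulated directly in cumulative form by this loop, so
    // c_batch_ptr(batch_size) is already the total number of nonzeros.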

    Tensor c_col_ind_t;
    Tensor c_values_t;

    const int total_nnz = c_batch_ptr(batch_size);

    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                           &c_col_ind_t));
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::value,
                                      TensorShape({total_nnz}), &c_values_t));
    OP_REQUIRES_OK(ctx,
                   CSRSparseMatrix::CreateCSRSparseMatrix(
                       DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t,
                       c_row_ptr_t, c_col_ind_t, c_values_t, &c));

    for (int i = 0; i < batch_size; ++i) {
      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
                                  a_input_matrix->col_indices_vec(i),
                                  a_input_matrix->values_vec<T>(i),
                                  a_input_dense_shape};
      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
                                  b_input_matrix->col_indices_vec(i),
                                  b_input_matrix->values_vec<T>(i),
                                  b_input_dense_shape};
      CSRComponent<T> c_comp{c.row_pointers_vec(i), c.col_indices_vec(i),
                             c.values_vec<T>(i), c_dense_shape};
      OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp, workspace));
    }

    Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    c_t.scalar<Variant>()() = std::move(c);
    ctx->set_output(0, c_t);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
};

#define REGISTER_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseMatMul") \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("type"),  \
                          CSRSparseMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#define REGISTER(DEV, T)                                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseMatMul") \
                              .Device(DEVICE_##DEV)        \
                              .TypeConstraint<T>("type"),  \
                          CSRSparseMatMulGPUOp<DEV##Device, T>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif  // GOOGLE_CUDA

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <typename T>
struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
    : public CSRStructureModifyingFunctor<GPUDevice, T> {
  explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
                                       bool adjoint_a, bool transpose_b)
      : ctx_(ctx),
        cuda_sparse_(ctx),
        initialized_(false),
        transpose_a_(transpose_a),
        adjoint_a_(adjoint_a),
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
        transpose_b_(transpose_b) {
#else
        transpose_b_(transpose_b),
        info_(nullptr) {
#endif  // CUDA_VERSION < 10000
    // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
    transA_ = adjoint_a
                  ? GPUSPARSE(OPERATION_CONJUGATE_TRANSPOSE)
                  : (transpose_a ? GPUSPARSE(OPERATION_TRANSPOSE)
                                 : GPUSPARSE(OPERATION_NON_TRANSPOSE));
    transB_ = transpose_b ? GPUSPARSE(OPERATION_TRANSPOSE)
                          : GPUSPARSE(OPERATION_NON_TRANSPOSE);
  }

#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  ~CSRSparseSparseMatrixMatMul() {
    if (initialized_) {
      cusparseDestroyCsrgemm2Info(info_);
    }
  }
#endif

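  // Typical call sequence, mirroring the cuSPARSE csrgemm2 workflow:
  // Initialize(), then GetWorkspaceSize() to size a scratch buffer,
  // GetOutputStructure() to compute the output row pointers and nnz count,
  // and finally Compute() to fill the output column indices and values.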
  Status Initialize() {
    if (adjoint_a_ && transpose_a_) {
      return errors::InvalidArgument(
          "Only one of adjoint_a and transpose_a may be true.");
    }

    TF_RETURN_IF_ERROR(cuda_sparse_.Initialize());
    TF_RETURN_IF_ERROR(descrA_.Initialize());
    TF_RETURN_IF_ERROR(descrB_.Initialize());
    TF_RETURN_IF_ERROR(descrC_.Initialize());
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
#endif
    initialized_ = true;
    return Status::OK();
  }

  Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                          const ConstCSRComponent<T>& b, size_t* bufferSize) {
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
    DCHECK(initialized_);
    const int m =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
    if (!transpose_a_) {
      DCHECK_EQ(m, a.row_ptr.size() - 1);
    }
    const int k =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
    if (!transpose_b_) {
      DCHECK_EQ(k, b.row_ptr.size() - 1);
    }
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n =
        b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
        m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
        bufferSize));
#endif

    return Status::OK();
  }

  Status GetOutputStructure(const ConstCSRComponent<T>& a,
                            const ConstCSRComponent<T>& b,
                            TTypes<int32>::UnalignedVec c_row_ptr,
                            int* output_nnz, void* workspace) {
    DCHECK(initialized_);

    const int m =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
    if (!transpose_a_) {
      DCHECK_EQ(m, a.row_ptr.size() - 1);
    }
    DCHECK_EQ(m, c_row_ptr.size() - 1);
    const int k =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
    if (!transpose_b_) {
      DCHECK_EQ(k, b.row_ptr.size() - 1);
    }
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n =
        b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));

    *output_nnz = -1;

#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
        transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
        a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
        b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
#else
    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
        m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
        descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
#endif

    if (*output_nnz < 0) {
      return errors::Internal(
          "CSRMatMul: CsrgemmNnz returned nnzTotalDevHostPtr < 0: ",
          *output_nnz);
    }
    return Status::OK();
  }

  Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
                 CSRComponent<T>* c, void* workspace) {
    DCHECK(initialized_);

    const int m =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
    if (!transpose_a_) {
      DCHECK_EQ(m, a.row_ptr.size() - 1);
    }
    DCHECK_EQ(m, c->dense_shape_host(c->dense_shape_host.size() - 2));
    DCHECK_EQ(m, c->row_ptr.size() - 1);
    const int k =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
    if (!transpose_b_) {
      DCHECK_EQ(k, b.row_ptr.size() - 1);
    }
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n =
        b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
    DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));

#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
        transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
        a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
        b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
        c->values.data(), c->row_ptr.data(), c->col_ind.data()));
#else
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
        m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
        a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
        c->row_ptr.data(), c->col_ind.data(), info_, workspace));
#endif

    // TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix
    // columns are sorted?  Above operation leads to unsorted columns.
    // For now, here is an example of how to ensure the output columns
    // are sorted.  Leaving here in case we realize we need to ensure
    // sorted columns in the future.
    //
    // TF_RETURN_IF_ERROR(cuda_sparse.Csru2csr(
    //     m, n, nnzTotalDevHostPtr, descrA_.descr(), c->values.data(),
    //     c->row_ptr.data(), c->col_ind.data()));

    return Status::OK();
  }

 private:
  OpKernelContext* ctx_;
  GpuSparse cuda_sparse_;
  bool initialized_;
  bool transpose_a_;
  bool adjoint_a_;
  bool transpose_b_;
  GpuSparseMatrixDescriptor descrA_;
  GpuSparseMatrixDescriptor descrB_;
  GpuSparseMatrixDescriptor descrC_;
  gpusparseOperation_t transA_;
  gpusparseOperation_t transB_;
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  csrgemm2Info_t info_;
#endif
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow