1  /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2  
3  Licensed under the Apache License, Version 2.0 (the "License");
4  you may not use this file except in compliance with the License.
5  You may obtain a copy of the License at
6  
7      http://www.apache.org/licenses/LICENSE-2.0
8  
9  Unless required by applicable law or agreed to in writing, software
10  distributed under the License is distributed on an "AS IS" BASIS,
11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  See the License for the specific language governing permissions and
13  limitations under the License.
14  ==============================================================================*/
15  
16  #define EIGEN_USE_THREADS
17  
18  #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
19  #define EIGEN_USE_GPU
20  #endif
21  
22  #include "third_party/eigen3/Eigen/Core"
23  #include "third_party/eigen3/Eigen/SparseCore"
24  #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25  #include "tensorflow/core/framework/op.h"
26  #include "tensorflow/core/framework/op_kernel.h"
27  #include "tensorflow/core/framework/tensor_types.h"
28  #include "tensorflow/core/framework/variant_op_registry.h"
29  #include "tensorflow/core/kernels/cwise_ops_common.h"
30  #include "tensorflow/core/kernels/dense_update_functor.h"
31  #include "tensorflow/core/kernels/fill_functor.h"
32  #include "tensorflow/core/kernels/sparse/kernels.h"
33  #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
34  #include "tensorflow/core/kernels/sparse/transpose_op.h"
35  #include "tensorflow/core/kernels/transpose_functor.h"
36  #include "tensorflow/core/lib/gtl/inlined_vector.h"
37  #include "tensorflow/core/platform/threadpool.h"
38  
39  #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
40  #include "tensorflow/core/kernels/cuda_solvers.h"
41  #include "tensorflow/core/kernels/cuda_sparse.h"
42  #endif
43  
44  namespace tensorflow {
45  
46  // TODO(anudhyan): These constants may be tuned based on the performance of
47  // 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
48  // which work across hardware platforms for typical matrix sizes. It should be
49  // possible to observe at least 30-50% improvement as we increase the number
50  // of threads by 1. If not, then it may be worth increasing kMaxShards and
51  // kNumShardsPerThread. However, once we have too many shards, latency may be
52  // dominated by per-shard overhead.
53  //
54  // Maximum number of shards into which to divide the computation for each CSR
55  // Sparse Matrix instance.
56  static constexpr int32 kMaxShards = 20;
57  // Number of shards allocated to each thread.
58  static constexpr int32 kNumShardsPerThread = 3;
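// For example (illustrative figures, not from any benchmark): with 8 intra-op
// threads, the kernels below shard each matrix instance into
// max(kMaxShards, kNumShardsPerThread * 8) = max(20, 24) = 24 row blocks, so
// a 4800-row matrix is processed in blocks of 4800 / 24 = 200 rows.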
59  
60  typedef Eigen::ThreadPoolDevice CPUDevice;
61  typedef Eigen::GpuDevice GPUDevice;
62  
63  // Abstract OpKernel to compute sparse-dense matrix multiplication.
64  //
65  // Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
66  // computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
67  // multiplication.
68  //
69  // The boolean attributes `transpose_a` and `adjoint_a` will transpose or
70  // adjoint `a` before multiplication, respectively. At most one of these
71  // attributes may be set to True. Corresponding attributes will transpose or
72  // adjoint `b` or the output (after multiplication).
73  //
74  // The rank of both `a` and `b` must be equal and their shapes must be
75  // compatible for matrix multiplication. Otherwise, InvalidArgument runtime
76  // errors will be thrown. Only rank 2 or rank 3 inputs are supported.
77  //
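// Schematically (a sketch of the semantics above): writing T(x) for transpose
// and H(x) for conjugate transpose, the kernel evaluates
//   c = out(op_a(a) * op_b(b)),
// where op_a is the identity, T, or H according to `transpose_a` / `adjoint_a`
// (likewise op_b for `transpose_b` / `adjoint_b`), and out applies
// `transpose_output` and/or `conjugate_output` to the product.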
78  template <typename Device, typename T>
79  class CSRMatMulOp : public OpKernel {
80   public:
81    explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
82      OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
83      OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
84      bool adjoint_a;
85      OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
86      OP_REQUIRES(c, !(adjoint_a && transpose_a_),
87                  errors::InvalidArgument(
88                      "Only one of adjoint_a and transpose_a may be true."));
89      bool adjoint_b;
90      OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
91      OP_REQUIRES(c, !(adjoint_b && transpose_b_),
92                  errors::InvalidArgument(
93                      "Only one of adjoint_b and transpose_b may be true."));
94      OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
95      OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
96      conjugate_a_ = adjoint_a;
97      conjugate_b_ = adjoint_b;
98      transpose_a_ |= adjoint_a;
99      transpose_b_ |= adjoint_b;
100    }
101  
102    ~CSRMatMulOp() override {}
103  
104    Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
105                          const Tensor& dense_tensor_b, int* rank,
106                          int64* batch_size) {
107      if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
108        return errors::InvalidArgument(
109            "Input types don't match.  a.dtype == ",
110            DataTypeString(sparse_matrix_a.dtype()),
111            " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
112      }
113      *rank = sparse_matrix_a.dims();
114      // TODO(ebrevdo): Add support for broadcasting matmul.
115      if (*rank != dense_tensor_b.dims()) {
116        return errors::InvalidArgument("Ranks of a and b must match, saw: ", *rank,
117                                       " vs. ", dense_tensor_b.dims(), ".");
118      }
119      // A valid CSR SparseMatrix has rank 2 or rank 3.
120      *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
121      if (sparse_matrix_a.batch_size() != *batch_size) {
122        return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
123                                       sparse_matrix_a.batch_size(), " vs. ",
124                                     *batch_size, ".");
125      }
126      const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64>();
127      const int64 a_inner_dim =
128          a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
129      const int64 b_inner_dim =
130          dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
131      if (a_inner_dim != b_inner_dim) {
132        return errors::InvalidArgument(
133            "Inner product dimensions of A and B do not agree.  Shapes are: ",
134            TensorShape(a_dense_shape), " vs. ",
135            dense_tensor_b.shape().DebugString());
136      }
137      return Status::OK();
138    }
139  
140   public:
141    bool transpose_a_;
142    bool transpose_b_;
143    bool conjugate_a_;
144    bool conjugate_b_;
145    bool transpose_output_;
146    bool conjugate_output_;
147  };
148  
149  // CPU Kernel to compute sparse-dense matrix multiplication.
150  //
151  // Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
152  // a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
153  // available, the implementation parallelizes the computation across the
154  // rows of the sparse matrix (and across batch dimensions).
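// Roughly (a sketch of the non-transposed path): each shard maps a contiguous
// block of rows of `a` as an Eigen::Map<const SparseMatrix>, multiplies it by
// a dense Eigen::Map over the rows of `b`, and writes the product directly
// into the corresponding rows of the output.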
155  template <typename T>
156  class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
157    using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
158    using Matrix =
159        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
160    using ConstMatrixMap = Eigen::Map<const Matrix>;
161    using MatrixMap = Eigen::Map<Matrix>;
162  
163   public:
164    explicit CSRMatMulCPUOp(OpKernelConstruction* c)
165        : CSRMatMulOp<CPUDevice, T>(c) {}
166  
167    ~CSRMatMulCPUOp() override {}
168  
169    void Compute(OpKernelContext* ctx) final {
170      const CSRSparseMatrix* sparse_matrix_a;
171      OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
172      const Tensor& matrix_b = ctx->input(1);
173  
174      int rank;
175      int64 batch_size;
176      OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
177                                               &batch_size));
178  
179      const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64>();
180      int64 num_lhs_rows = dense_shape(rank - 2);
181      int64 num_lhs_cols = dense_shape(rank - 1);
182      int64 num_rhs_rows = matrix_b.dim_size(rank - 2);
183      int64 num_rhs_cols = matrix_b.dim_size(rank - 1);
184  
185      if (this->transpose_a_) {
186        std::swap(num_lhs_rows, num_lhs_cols);
187      }
188  
189      // Possibly transpose the dense Tensor b.
190      const Tensor* rhs = &matrix_b;
191      Tensor b_transposed;
192      if (this->transpose_b_) {
193        OP_REQUIRES_OK(
194            ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
195                                             &b_transposed));
196        rhs = &b_transposed;
197        std::swap(num_rhs_rows, num_rhs_cols);
198      }
199  
200      // If we're transposing the output, then allocate a temporary buffer to
201      // store the output. Otherwise allocate the output directly.
202      Tensor* output = nullptr;
203      Tensor* matmul_result = nullptr;
204      Tensor output_transposed;
205      OP_REQUIRES_OK(
206          ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
207                              this->transpose_output_, &output,
208                              &output_transposed, &matmul_result));
209  
210      if (!this->transpose_a_) {
211        SparseDenseMatMulWithoutTransposedLHS(
212            ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
213      } else {  // transpose_a_ == true
214        SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
215                                           num_lhs_cols, *sparse_matrix_a, *rhs,
216                                           matmul_result);
217      }
218  
219      // Transpose (and conjugate) the output if necessary.
220      // When transposing, any requested conjugation is fused into that step.
221      if (this->transpose_output_) {
222        OP_REQUIRES_OK(
223            ctx, TransposeAndConjugateAllocatedTensor(
224                     ctx, output_transposed, this->conjugate_output_, output));
225      } else if (this->conjugate_output_) {
226        functor::maybe_conj_inplace<CPUDevice, T>::run(
227            ctx->eigen_device<CPUDevice>(), output);
228      }
229    }
230  
231   private:
232    // Allocates the output with the appropriate shape. Additionally, if
233    // transpose_output is True, allocates a temporary buffer with the transposed
234    // output. 'matmul_result' points to either output or output_transposed, based
235    // on whether transpose_output is True.
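  // For example (hypothetical sizes): with rank = 3, batch_size = 2,
  // num_rows = 5, num_cols = 7 and transpose_output = true, 'output' is
  // allocated with shape [2, 7, 5] while 'output_transposed' (and hence
  // 'matmul_result') gets shape [2, 5, 7].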
236    Status AllocateOutput(OpKernelContext* ctx, const int32 rank,
237                          const int64 batch_size, const int64 num_rows,
238                          const int64 num_cols, const bool transpose_output,
239                          Tensor** output, Tensor* output_transposed,
240                          Tensor** matmul_result) {
241      TensorShape output_shape;
242      if (rank == 3) output_shape.AddDim(batch_size);
243  
244      if (!transpose_output) {
245        output_shape.AppendShape({num_rows, num_cols});
246        TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
247        *matmul_result = *output;
248      } else {
249        TensorShape output_transposed_shape = output_shape;
250        output_transposed_shape.AppendShape({num_rows, num_cols});
251        output_shape.AppendShape({num_cols, num_rows});
252        TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
253                                              output_transposed_shape,
254                                              output_transposed));
255        TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
256        *matmul_result = output_transposed;
257      }
258      return Status::OK();
259    }
260  
261    // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
262    // contiguous segment of rows of the CSR Sparse Matrix.
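  // For example (hypothetical values): if row_pointers_vec(batch_index) is
  // [0, 2, 5, 7, 9], row_begin = 1 and num_shard_rows = 2, then row_offset = 2
  // and *row_ptrs becomes [0, 3, 5], so the returned Ref aliases entries
  // [2, 7) of the values and column indices arrays.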
263    Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
264        const CSRSparseMatrix& csr_matrix, const int batch_index,
265        const int64 row_begin, const int64 num_shard_rows,
266        std::vector<int32>* row_ptrs) {
267      // Compute the row pointers of the sparse sub-matrix.
268      row_ptrs->resize(num_shard_rows + 1);
269      const int64 row_offset =
270          csr_matrix.row_pointers_vec(batch_index)(row_begin);
271      for (int64 row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
272        row_ptrs->at(row_idx) =
273            csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
274            row_offset;
275      }
276      const int64 num_cols =
277          csr_matrix.dense_shape().vec<int64>()(csr_matrix.dims() - 1);
278      return Eigen::Map<const SparseMatrix>(
279          num_shard_rows /* num_rows */, num_cols /* num_cols */,
280          row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
281          csr_matrix.col_indices_vec(batch_index).data() + row_offset,
282          csr_matrix.values_vec<T>(batch_index).data() + row_offset);
283    }
284  
285    // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
286    // dense Tensor (RHS).
287    void SparseDenseMatMulWithoutTransposedLHS(
288        OpKernelContext* ctx, const int64 batch_size, const int64 num_lhs_rows,
289        const CSRSparseMatrix& lhs, const Tensor& rhs, Tensor* output) {
290      // Parallelize matrix multiplication across batch dimensions and across
291      // rows in each batch.
292      auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
293      const int32 num_threads = worker_threads.num_threads;
294      const int64 block_size =
295          num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
296      const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
297      const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
298      worker_threads.workers->ParallelFor(
299          batch_size * num_lhs_rows /* total */,
300          thread::ThreadPool::SchedulingParams(
301              thread::ThreadPool::SchedulingStrategy::
302                  kFixedBlockSize /* strategy */,
303              absl::nullopt /* cost_per_unit */, block_size),
304          [&](int64 batch_and_row_begin, int64 batch_and_row_end) {
305            HandleBatchAndRowRange(
306                num_lhs_rows, batch_and_row_begin, batch_and_row_end,
307                [&](int64 batch_idx, int64 row_begin, int64 row_end) {
308                  const int64 num_shard_rows = row_end - row_begin;
309  
310                  // Define an Eigen::SparseMatrix over the row range:
311                  // [row_begin, row_end) of the CSR SparseMatrix A.
312                  std::vector<int32> row_ptrs;
313                  auto sparse_matrix = GetSparseMatrixRef(
314                      lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);
315  
316                  // Map the corresponding rows of the rhs.
317                  ConstMatrixMap rhs_map(rhs.flat<T>().data() + batch_idx *
318                                                                    num_rhs_rows *
319                                                                    num_rhs_cols,
320                                         num_rhs_rows, num_rhs_cols);
321  
322                  // Write to the corresponding rows of the output matrix.
323                  MatrixMap output_map(
324                      output->flat<T>().data() +
325                          batch_idx * num_lhs_rows * num_rhs_cols +
326                          row_begin * num_rhs_cols,
327                      num_shard_rows, num_rhs_cols);
328                  output_map.noalias() = sparse_matrix * rhs_map;
329                });
330          });
331    }
332  
333    // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
334    // to be transposed before the operation.
335    void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
336                                            const int64 batch_size,
337                                            const int64 num_lhs_rows,
338                                            const int64 num_lhs_cols,
339                                            const CSRSparseMatrix& lhs,
340                                            const Tensor& rhs, Tensor* output) {
341      auto device = ctx->eigen_device<CPUDevice>();
342      auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
343      const int32 num_threads = worker_threads.num_threads;
344      const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
345      const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
346      // Usually, we want to avoid transposing the sparse matrix A since it may be
347      // an expensive operation. Instead, we use the identity (A^T B) = (B^T A)^T.
348      // We don't actually transpose B or the output because it is more convenient
349      // to have them in column major form.
350      //
351      // However, if A is hypersparse and B and C are huge, transposing A will be
352      // cheaper. In the future, we should have a cost model estimating the cost
353      // of transposing all matrices (A, B, C) to decide which variant to use.
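    // Concretely (a sketch of the sharded update below): writing A_k for the
    // rows [row_begin, row_end) of A and B_k for the same rows of B, each
    // shard accumulates C^T += B_k^T * A_k into its thread-local buffer, one
    // term of C = A^T * B = sum_k A_k^T * B_k.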
354  
355      // Each thread writes to its own copy of the matrix product. These
356      // `num_threads` copies are summed together to obtain the final result.
357      Tensor matmul_result_buffer;
358      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
359                                             TensorShape({num_threads + 1,
360                                                          output->NumElements()}),
361                                             &matmul_result_buffer));
362      functor::SetZeroFunctor<CPUDevice, T> set_zero;
363      set_zero(device, matmul_result_buffer.flat<T>());
364  
365      // Parallelize matrix multiplication across batch dimensions and across
366      // columns of A^T in each batch. These correspond to rows of A.
367      const int64 block_size =
368          num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
369      worker_threads.workers->ParallelForWithWorkerId(
370          batch_size * num_lhs_cols /* total */,
371          thread::ThreadPool::SchedulingParams(
372              thread::ThreadPool::SchedulingStrategy::
373                  kFixedBlockSize /* strategy */,
374              absl::nullopt /* cost_per_unit */, block_size),
375          [&](int64 batch_and_row_begin, int64 batch_and_row_end, int tid) {
376            HandleBatchAndRowRange(
377                num_lhs_cols, batch_and_row_begin, batch_and_row_end,
378                [&](int64 batch_idx, int64 row_begin, int64 row_end) {
379                  const int64 num_shard_rows = row_end - row_begin;
380  
381                  // Define a new sparse sub-matrix from the row range
382                  // [row_begin, row_end) of the sparse matrix A.
383                  std::vector<int32> row_ptrs;
384                  auto sparse_matrix = GetSparseMatrixRef(
385                      lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);
386  
387                  // Map the corresponding `num_shard_rows` columns of B^T.
388                  // This is the same as taking the `num_shard_rows` rows of B.
389                  ConstMatrixMap b_dense_map(
390                      rhs.flat<T>().data() +
391                          batch_idx * num_rhs_rows * num_rhs_cols +
392                          row_begin * num_rhs_cols,
393                      num_shard_rows, num_rhs_cols);
394  
395                  // Map to the corresponding rows of the output.
396                  MatrixMap output_map(
397                      matmul_result_buffer.flat<T>().data() +
398                          tid * batch_size * num_lhs_rows * num_rhs_cols +
399                          batch_idx * num_lhs_rows * num_rhs_cols,
400                      num_lhs_rows, num_rhs_cols);
401  
402                  // Compute the product C^T = B^T * A; restricted to the row
403                  // range in the current shard.
404                  if (this->conjugate_a_) {
405                    output_map.transpose().noalias() +=
406                        b_dense_map.transpose() * sparse_matrix.conjugate();
407                  } else {
408                    output_map.transpose().noalias() +=
409                        b_dense_map.transpose() * sparse_matrix;
410                  }
411                });
412          });
413  
414      // Sum across each thread's matmul result.
415      using Reducer = Eigen::internal::SumReducer<T>;
416      using Index = typename TTypes<T>::Tensor::Index;
417      output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
418          Eigen::array<Index, 1>({0}), Reducer());
419    }
420  
421    // Given a range [batch_and_row_begin, batch_and_row_end) which is a
422    // contiguous subset of [0, num_rows * batch_size), calls the function
423    // fn(batch_idx, row_begin, row_end) for each batch index
424    // and the row range [row_begin, row_end) contained in the batch.
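  // For example (hypothetical values): with num_rows = 10, batch_size >= 3 and
  // the range [8, 23), fn is called as fn(0, 8, 10), fn(1, 0, 10) and
  // fn(2, 0, 3).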
425    void HandleBatchAndRowRange(
426        const int64 num_rows, const int64 batch_and_row_begin,
427        const int64 batch_and_row_end,
428        const std::function<void(int64, int64, int64)>& fn) {
429      // Obtain the batch indices overlapping with the current shard.
430      const int64 batch_begin = batch_and_row_begin / num_rows;
431      const int64 batch_end_inclusive = batch_and_row_end / num_rows;
432  
433      for (int64 batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
434           ++batch_idx) {
435        // Find the contiguous set of rows which are contained in this shard as
436        // well as the current batch. We intersect with interval [batch_idx *
437        // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch.
438        const int64 current_batch_row_begin =
439            std::max(batch_and_row_begin, batch_idx * num_rows);
440        const int64 current_batch_row_end =
441            std::min(batch_and_row_end, (batch_idx + 1) * num_rows);
442  
443        const int64 row_begin = current_batch_row_begin % num_rows;
444        const int64 num_shard_rows =
445            current_batch_row_end - current_batch_row_begin;
446        // Edge case for when current_batch_row_end is exactly the first index
447        // of a new batch.
448        if (num_shard_rows == 0) continue;
449  
450        fn(batch_idx, row_begin, row_begin + num_shard_rows);
451      }
452    }
453  
454    // Transposes (and optionally, conjugates) a given Tensor. Also allocates the
455    // required memory for the output Tensor.
456    Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
457                                       bool conjugate, Tensor* output) {
458      TensorShape transposed_shape = input.shape();
459      transposed_shape.set_dim(input.dims() - 1,
460                               input.dim_size(input.dims() - 2));
461      transposed_shape.set_dim(input.dims() - 2,
462                               input.dim_size(input.dims() - 1));
463      TF_RETURN_IF_ERROR(
464          ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
465      return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
466    }
467  
468    // Transposes (and optionally, conjugates) a given Tensor. The output should
469    // be already allocated.
470    Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
471                                                const Tensor& input,
472                                                bool conjugate, Tensor* output) {
473      if (conjugate) {
474        TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
475            ctx->eigen_device<CPUDevice>(), input, output));
476      } else {
477        TF_RETURN_IF_ERROR(
478            DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
479      }
480      return Status::OK();
481    }
482  };
483  
484  // GPU Kernel to compute sparse-dense matrix multiplication.
485  template <typename T>
486  class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
487    using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
488    using Matrix =
489        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
490    using ConstMatrixMap = Eigen::Map<const Matrix>;
491    using MatrixMap = Eigen::Map<Matrix>;
492  
493   public:
494    explicit CSRMatMulGPUOp(OpKernelConstruction* c)
495        : CSRMatMulOp<GPUDevice, T>(c) {}
496  
497    ~CSRMatMulGPUOp() override {}
498  
499    void Compute(OpKernelContext* ctx) final {
500      const CSRSparseMatrix* a_matrix;
501      OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
502      const Tensor& b_t = ctx->input(1);
503  
504      int rank;
505      int64 batch_size;
506      OP_REQUIRES_OK(ctx,
507                     this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));
508  
509      const Tensor& a_dense_shape_t = a_matrix->dense_shape();
510      TensorShape a_dense_tensor_shape;
511      auto a_dense_shape = a_dense_shape_t.vec<int64>();
512      OP_REQUIRES_OK(
513          ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));
514  
515      const int row_dim = (rank == 2) ? 0 : 1;
516      const int64 a_outer_dim = a_dense_tensor_shape.dim_size(
517          this->transpose_a_ ? row_dim + 1 : row_dim);
518      const int64 b_inner_dim =
519          b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
520      const int64 b_outer_dim =
521          b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
522      const int64 b_slice_size = b_inner_dim * b_outer_dim;
523  
524      TensorShape c_shape;
525      if (rank == 3) c_shape.AddDim(batch_size);
526      if (this->transpose_output_) {
527        c_shape.AddDim(b_outer_dim);
528        c_shape.AddDim(a_outer_dim);
529      } else {
530        c_shape.AddDim(a_outer_dim);
531        c_shape.AddDim(b_outer_dim);
532      }
533  
534      const int64 c_matrix_lhs = c_shape.dim_size(row_dim);
535      const int64 c_matrix_rhs = c_shape.dim_size(row_dim + 1);
536      const int64 c_slice_size = c_matrix_lhs * c_matrix_rhs;
537      Tensor* c_t;
538      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
539  
540      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
541  
542      if (b_outer_dim == 1) {
543        // Call matrix-vector multiply if b is a vector.
544        TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
545                                                   2);
546        Tensor b_conj_t;
547        const T* b_base_ptr = b_t.template flat<T>().data();
548        bool conjugate_a = this->conjugate_a_;
549        bool conjugate_output = this->conjugate_output_;
550        if (this->conjugate_b_) {
551          if (conjugate_a) {
552            // In this case we can use the identity
553            //   conj(a) * conj(b) = conj(a * b)
554            // instead of creating a conjugated copy of b.
555            conjugate_a = false;
556            conjugate_output = !conjugate_output;
557          } else {
558            OP_REQUIRES_OK(
559                ctx, ctx->forward_input_or_allocate_temp(
560                         {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
561            functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
562            b_base_ptr = b_conj_t.template flat<T>().data();
563          }
564        }
565  
566        functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
567                                                              conjugate_a);
568        for (int i = 0; i < batch_size; ++i) {
569          auto a_row_ptr = a_matrix->row_pointers_vec(i);
570          auto a_col_ind = a_matrix->col_indices_vec(i);
571          auto a_values = a_matrix->values_vec<T>(i);
572          ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
573                                      a_dense_shape_comp};
574          const T* b_i = b_base_ptr + i * b_slice_size;
575          T* c_i = &c_t->template flat<T>()(i * c_slice_size);
576          Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
577          OP_REQUIRES_OK(ctx, s);
578        }
579        if (conjugate_output) {
580          functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
581        }
582        return;
583      }
584  
585      functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
586          this->transpose_output_);
587  
588      Tensor c_mat_col_major_t;
589      if (!this->transpose_output_) {
590        // If transpose_output is false, we'll need to transpose the (col
591        // major) output of the csrgemm call to get proper (row-major)
592        // output, which means we need to keep a temporary buffer to
593        // store the intermediate gemm output.
594        TensorShape c_mat_col_major_shape;
595        if (rank == 2) {
596          c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
597        } else {
598          c_mat_col_major_shape =
599              TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
600        }
601        OP_REQUIRES_OK(
602            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
603                                    c_mat_col_major_shape, &c_mat_col_major_t));
604      }
605  
606      // If transpose_output is true, return the direct (column-major i.e.,
607      // transposed) output of the csrgemm call.  Otherwise we'll need
608      // to transpose it to row major format.
609      auto c_mat_col_major = (this->transpose_output_)
610                                 ? c_t->flat<T>()
611                                 : c_mat_col_major_t.flat<T>();
612  
613      // Possibly transpose a.
614      const CSRSparseMatrix* a_input_matrix;
615      // If we need to transpose a, we will store the result temporarily
616      // in the object below.
617      CSRSparseMatrix a_matrix_transposed;
618      if (!this->transpose_a_) {
619        a_input_matrix = a_matrix;
620      } else {
621        functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
622        OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
623                                      &a_matrix_transposed));
624        a_input_matrix = &a_matrix_transposed;
625      }
626  
627      auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();
628  
629      // Possibly transpose b.
630      Tensor b_t_input;
631      if (!this->transpose_b_) {
632        b_t_input = b_t;
633      } else {
634        TensorShape b_t_transposed_shape;
635        if (rank == 3) {
636          b_t_transposed_shape.AddDim(batch_size);
637        }
638        b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
639        b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
640        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
641                                               b_t_transposed_shape, &b_t_input));
642        const GPUDevice& d = ctx->eigen_device<GPUDevice>();
643        if (this->conjugate_b_) {
644          OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
645                                                         &b_t_input /*output*/));
646        } else {
647          OP_REQUIRES_OK(
648              ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
649        }
650      }
651  
652      // Dense shape of a batch component of A.
653      TTypes<int64>::ConstVec a_input_dense_shape_comp(
654          a_input_dense_shape.data() + row_dim, 2);
655  
656      auto b = b_t_input.flat<T>();
657  
658      for (int i = 0; i < batch_size; ++i) {
659        auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
660        auto a_col_ind = a_input_matrix->col_indices_vec(i);
661        auto a_values = a_input_matrix->values_vec<T>(i);
662        typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
663                                                     {b_inner_dim, b_outer_dim});
664        typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
665            c_mat_col_major.data() + i * c_slice_size,
666            {c_matrix_lhs, c_matrix_rhs});
667        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
668                                    a_input_dense_shape_comp};
669        Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
670        OP_REQUIRES_OK(ctx, s);
671      }
672  
673      if (!this->transpose_output_) {
674        // We need to return values in row major format, so transpose
675        // the column-major values in c_mat_col_major_t to row-major output c_t.
676        OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
677                                              /*output=*/c_t));
678      }
679      if (this->conjugate_output_) {
680        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
681      }
682    }
683  };
684  
685  #define REGISTER_CPU(T)                                                     \
686    REGISTER_KERNEL_BUILDER(                                                  \
687        Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
688        CSRMatMulCPUOp<T>);
689  
690  REGISTER_CPU(float)
691  REGISTER_CPU(double)
692  REGISTER_CPU(complex64)
693  REGISTER_CPU(complex128)
694  
695  #undef REGISTER_CPU
696  
697  #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
698  
699  #define REGISTER_GPU(T)                                                     \
700    REGISTER_KERNEL_BUILDER(                                                  \
701        Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
702        CSRMatMulGPUOp<T>);
703  
704  REGISTER_GPU(float)
705  REGISTER_GPU(double)
706  #if GOOGLE_CUDA
707  REGISTER_GPU(complex64)
708  REGISTER_GPU(complex128)
709  #endif
710  
711  #undef REGISTER_GPU
712  
713  #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
714  
715  #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
716  
717  namespace functor {
718  
719  template <typename T>
720  class CSRSparseMatrixMatMul<GPUDevice, T> {
721   public:
722    explicit CSRSparseMatrixMatMul(const bool transpose_output)
723        : transpose_output_(transpose_output) {}
724  
725    Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
726                   typename TTypes<T>::UnalignedConstMatrix b,
727                   typename TTypes<T>::UnalignedMatrix c) {
728      GpuSparse cuda_sparse(ctx);
729      TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
730      {
731        // Use Csrmm to calculate:
732        //   C = alpha * op(A) * op(B) + beta * C
733        // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
734        // Note that Csrmm assumes B and C are in column-major form; so we
735        // use transB == true, and manually transpose the output in place
736        // using blas<t>geam.
737        // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.
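      // In other words (sketch): the row-major (k, n) matrix b is bitwise
      // identical to a column-major (n, k) matrix, so passing it with
      // transB == transpose makes op(B) the intended (k, n) operand, and the
      // (m, n) product comes back column-major (i.e. as C^T in row-major
      // terms), which is why the caller transposes it afterwards when
      // transpose_output is false.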
738  
739        // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
740        // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
741        const T alpha = 1;
742        const T beta = 0;
743  
744        // transA must be non-transpose if transB is transpose (cusparse
745        // limitation).
746  #if GOOGLE_CUDA
747        const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
748  #elif TENSORFLOW_USE_ROCM
749        const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
750  #endif
751  
752        // transB: b is row-major, and cusparse requires col-major b (or
753        // equivalently transB == transpose).  This version is actually more
754        // efficient.
755  #if GOOGLE_CUDA
756        const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
757  
758        gpusparseMatDescr_t descrA;
759        TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
760        TF_RETURN_IF_GPUSPARSE_ERROR(
761            cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
762        TF_RETURN_IF_GPUSPARSE_ERROR(
763            cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
764  #elif TENSORFLOW_USE_ROCM
765        const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
766  
767        gpusparseMatDescr_t descrA;
768        TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
769        TF_RETURN_IF_GPUSPARSE_ERROR(
770            hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
771        TF_RETURN_IF_GPUSPARSE_ERROR(
772            hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
773  #endif
774  
775        // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
776        const int k = b.dimension(0);
777        DCHECK_EQ(k, a.dense_shape_host(1));
778  
779        // If transpose_output_ is true, then the c matrix we receive
780        // here is the direct row major output (into which we will store
781        // csrgemm's col major output).  Otherwise it's a
782        // temporary tensor that will store the column major output that
783        // will eventually be transposed.
784        const int m = c.dimension(transpose_output_ ? 1 : 0);
785        const int n = c.dimension(transpose_output_ ? 0 : 1);
786        DCHECK_EQ(m, a.dense_shape_host(0));
787        DCHECK_EQ(n, b.dimension(1));
788        const int nnz = a.values.size();
789        DCHECK_EQ(nnz, a.col_ind.size());
790  
791        // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k)
792        // if op(A) = A and at least max(1, m) otherwise. If op(B) != B, it must
793        // be at least max(1, n).
794        const int ldb = n;
795        // ldc: leading dimension of C. It must be at least max(1, m) if
796        // op(A) = A and at least max(1, k) otherwise.
797        const int ldc = m;
798  
799        TF_RETURN_IF_ERROR(
800            cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
801                              a.values.data(), a.row_ptr.data(), a.col_ind.data(),
802                              b.data(), ldb, &beta, c.data(), ldc));
803      }
804  
805      return Status::OK();
806    }
807  
808   private:
809    bool transpose_output_;
810  };
811  
812  template <typename T>
813  class CSRSparseMatrixMatVec<GPUDevice, T> {
814   public:
815    CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
816        : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
817                                                     &status_)) {}
818  
819    Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
820                   const T* x, T* y) {
821      TF_RETURN_IF_ERROR(status_);
822      GpuSparse cuda_sparse(ctx);
823      TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
824      {
825        // Use Csrmv to calculate:
826        //   y = alpha * op(A) * x + beta * y
827        // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
828        // dense vectors.
829  
830        // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
831        // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
832        const T alpha = 1;
833        const T beta = 0;
834  
835        gpusparseMatDescr_t descrA;
836  #if GOOGLE_CUDA
837        TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
838        TF_RETURN_IF_GPUSPARSE_ERROR(
839            cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
840        TF_RETURN_IF_GPUSPARSE_ERROR(
841            cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
842  #elif TENSORFLOW_USE_ROCM
843        TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
844        TF_RETURN_IF_GPUSPARSE_ERROR(
845            hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
846        TF_RETURN_IF_GPUSPARSE_ERROR(
847            hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
848  #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
849  
850        const int m = a.dense_shape_host(0);
851        const int n = a.dense_shape_host(1);
852        const int nnz = a.values.size();
853        DCHECK_EQ(nnz, a.col_ind.size());
854        TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
855                                             a.values.data(), a.row_ptr.data(),
856                                             a.col_ind.data(), x, &beta, y));
857      }
858  
859      return Status::OK();
860    }
861  
862   private:
863    Status status_;
864    const gpusparseOperation_t transA_;
865  };
866  
867  }  // namespace functor
868  
869  #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
870  
871  }  // namespace tensorflow
872