/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

// TODO(anudhyan): These constants may be tuned based on the performance of
// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
// which work across hardware platforms for typical matrix sizes. It should be
// possible to observe at least 30-50% improvement as we increase the number
// of threads by 1. If not, then it may be worth increasing kMaxShards and
// kNumShardsPerThread. However, once we have too many shards, latency may be
// dominated by per-shard overhead.
//
// Maximum number of shards into which to divide the computation for each CSR
// Sparse Matrix instance.
static constexpr int32_t kMaxShards = 20;
// Number of shards allocated to each thread.
static constexpr int32_t kNumShardsPerThread = 3;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Abstract OpKernel to compute sparse-dense matrix multiplication.
//
// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
// multiplication.
//
// The boolean attributes `transpose_a` and `adjoint_a` will transpose or
// adjoint `a` before multiplication, respectively. At most one of these
// attributes may be set to True. Corresponding attributes will transpose or
// adjoint `b` or the output (after multiplication).
//
// The rank of both `a` and `b` must be equal and their shapes must be
// compatible for matrix multiplication. Otherwise, InvalidArgument runtime
// errors will be thrown. Only rank 2 or rank 3 inputs are supported.
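//
// For example, with rank-3 inputs where `a` has dense shape [B, M, K] and `b`
// has shape [B, K, N] (and no transpositions requested), the output `c` has
// shape [B, M, N], or [B, N, M] when `transpose_output` is true.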
//
template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    // Internally, adjoint is handled as conjugate + transpose.
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ |= adjoint_a;
    transpose_b_ |= adjoint_b;
  }

  ~CSRMatMulOp() override {}

  Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
                        const Tensor& dense_tensor_b, int* rank,
                        int64* batch_size) {
    if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
      return errors::InvalidArgument(
          "Input types don't match. a.dtype == ",
          DataTypeString(sparse_matrix_a.dtype()),
          " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
    }
    *rank = sparse_matrix_a.dims();
    // TODO(ebrevdo): Add support for broadcasting matmul.
    if (*rank != dense_tensor_b.dims()) {
      return errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                     *rank, " vs. ", dense_tensor_b.dims(),
                                     ".");
    }
    // A valid CSR SparseMatrix has rank 2 or rank 3.
    *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
    if (sparse_matrix_a.batch_size() != *batch_size) {
      return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
                                     sparse_matrix_a.batch_size(), " vs. ",
                                     *batch_size, ".");
    }
    const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64>();
    const int64_t a_inner_dim =
        a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
    const int64_t b_inner_dim =
        dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
    if (a_inner_dim != b_inner_dim) {
      return errors::InvalidArgument(
          "Inner product dimensions of A and B do not agree. Shapes are: ",
          TensorShape(a_dense_shape), " vs. ",
          dense_tensor_b.shape().DebugString());
    }
    return Status::OK();
  }

 public:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};

// CPU Kernel to compute sparse-dense matrix multiplication.
//
// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
// available, the implementation parallelizes the computation across each row
// of the sparse matrix.
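// Work is sharded with a fixed block size chosen so that the rows of each
// batch are split into roughly max(kMaxShards, kNumShardsPerThread *
// num_threads) blocks; see the block_size computations below.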
template <typename T>
class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulCPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<CPUDevice, T>(c) {}

  ~CSRMatMulCPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* sparse_matrix_a;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
    const Tensor& matrix_b = ctx->input(1);

    int rank;
    int64_t batch_size;
    OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
                                             &batch_size));

    const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64>();
    int64_t num_lhs_rows = dense_shape(rank - 2);
    int64_t num_lhs_cols = dense_shape(rank - 1);
    int64_t num_rhs_rows = matrix_b.dim_size(rank - 2);
    int64_t num_rhs_cols = matrix_b.dim_size(rank - 1);

    if (this->transpose_a_) {
      std::swap(num_lhs_rows, num_lhs_cols);
    }

    // Possibly transpose the dense Tensor b.
    const Tensor* rhs = &matrix_b;
    Tensor b_transposed;
    if (this->transpose_b_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
                                           &b_transposed));
      rhs = &b_transposed;
      std::swap(num_rhs_rows, num_rhs_cols);
    }

    // If we're transposing the output, then allocate a temporary buffer to
    // store the output. Otherwise allocate the output directly.
    Tensor* output = nullptr;
    Tensor* matmul_result = nullptr;
    Tensor output_transposed;
    OP_REQUIRES_OK(
        ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
                            this->transpose_output_, &output,
                            &output_transposed, &matmul_result));

    if (!this->transpose_a_) {
      SparseDenseMatMulWithoutTransposedLHS(
          ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
    } else {  // transpose_a_ == true
      SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
                                         num_lhs_cols, *sparse_matrix_a, *rhs,
                                         matmul_result);
    }

    // Transpose (and conjugate) the output if necessary.
    // Note that conjugate is only true if transpose is also true.
    if (this->transpose_output_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateAllocatedTensor(
                   ctx, output_transposed, this->conjugate_output_, output));
    } else if (this->conjugate_output_) {
      functor::maybe_conj_inplace<CPUDevice, T>::run(
          ctx->eigen_device<CPUDevice>(), output);
    }
  }

 private:
  // Allocates the output with the appropriate shape. Additionally, if
  // transpose_output is True, allocates a temporary buffer with the transposed
  // output. 'matmul_result' points to either output or output_transposed,
  // based on whether transpose_output is True.
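  // For example, with rank == 3 and transpose_output == true, the output is
  // allocated with shape [batch_size, num_cols, num_rows] while the temporary
  // 'output_transposed' is allocated with shape
  // [batch_size, num_rows, num_cols].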
  Status AllocateOutput(OpKernelContext* ctx, const int32_t rank,
                        const int64_t batch_size, const int64_t num_rows,
                        const int64_t num_cols, const bool transpose_output,
                        Tensor** output, Tensor* output_transposed,
                        Tensor** matmul_result) {
    TensorShape output_shape;
    if (rank == 3) output_shape.AddDim(batch_size);

    if (!transpose_output) {
      output_shape.AppendShape({num_rows, num_cols});
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = *output;
    } else {
      TensorShape output_transposed_shape = output_shape;
      output_transposed_shape.AppendShape({num_rows, num_cols});
      output_shape.AppendShape({num_cols, num_rows});
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                            output_transposed_shape,
                                            output_transposed));
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = output_transposed;
    }
    return Status::OK();
  }

  // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
  // contiguous segment of rows of the CSR Sparse Matrix.
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int batch_index,
      const int64_t row_begin, const int64_t num_shard_rows,
      std::vector<int32>* row_ptrs) {
    // Compute the row pointers of the sparse sub-matrix.
    row_ptrs->resize(num_shard_rows + 1);
    const int64_t row_offset =
        csr_matrix.row_pointers_vec(batch_index)(row_begin);
    for (int64_t row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
      row_ptrs->at(row_idx) =
          csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
          row_offset;
    }
    const int64_t num_cols =
        csr_matrix.dense_shape().vec<int64>()(csr_matrix.dims() - 1);
    return Eigen::Map<const SparseMatrix>(
        num_shard_rows /* num_rows */, num_cols /* num_cols */,
        row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
        csr_matrix.col_indices_vec(batch_index).data() + row_offset,
        csr_matrix.values_vec<T>(batch_index).data() + row_offset);
  }

  // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
  // dense Tensor (RHS).
  void SparseDenseMatMulWithoutTransposedLHS(OpKernelContext* ctx,
                                             const int64_t batch_size,
                                             const int64_t num_lhs_rows,
                                             const CSRSparseMatrix& lhs,
                                             const Tensor& rhs,
                                             Tensor* output) {
    // Parallelize matrix multiplication across batch dimensions and across
    // rows in each batch.
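    // The ParallelFor below iterates over a flattened index space of size
    // batch_size * num_lhs_rows; HandleBatchAndRowRange() splits each shard's
    // [begin, end) range back into per-batch row ranges.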
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32_t num_threads = worker_threads.num_threads;
    const int64_t block_size =
        num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    worker_threads.workers->ParallelFor(
        batch_size * num_lhs_rows /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64_t batch_and_row_begin, int64_t batch_and_row_end) {
          HandleBatchAndRowRange(
              num_lhs_rows, batch_and_row_begin, batch_and_row_end,
              [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
                const int64_t num_shard_rows = row_end - row_begin;

                // Define an Eigen::SparseMatrix over the row range:
                // [row_begin, row_end) of the CSR SparseMatrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding rows of the rhs.
                ConstMatrixMap rhs_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols,
                    num_rhs_rows, num_rhs_cols);

                // Write to the corresponding rows of the output matrix.
                MatrixMap output_map(
                    output->flat<T>().data() +
                        batch_idx * num_lhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);
                output_map.noalias() = sparse_matrix * rhs_map;
              });
        });
  }

  // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
  // to be transposed before the operation.
  void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
                                          const int64_t batch_size,
                                          const int64_t num_lhs_rows,
                                          const int64_t num_lhs_cols,
                                          const CSRSparseMatrix& lhs,
                                          const Tensor& rhs, Tensor* output) {
    auto device = ctx->eigen_device<CPUDevice>();
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32_t num_threads = worker_threads.num_threads;
    const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    // Usually, we want to avoid transposing the sparse matrix A since it may
    // be an expensive operation. Instead, we use the identity
    // (A^T B) = (B^T A)^T. We don't actually transpose B or the output
    // because it is more convenient to have them in column major form.
    //
    // However, if A is hypersparse and B and C are huge, transposing A will be
    // cheaper. In the future, we should have a cost model estimating the cost
    // of transposing all matrices (A, B, C) to decide which variant to use.

    // Each thread writes to its own copy of the matrix product. These
    // `num_threads` copies are summed together to obtain the final result.
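    // The buffer has num_threads + 1 rows because ParallelForWithWorkerId may
    // also run shards on the calling thread, so worker ids can span
    // [0, num_threads].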
    Tensor matmul_result_buffer;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({num_threads + 1, output->NumElements()}),
                       &matmul_result_buffer));
    functor::SetZeroFunctor<CPUDevice, T> set_zero;
    set_zero(device, matmul_result_buffer.flat<T>());

    // Parallelize matrix multiplication across batch dimensions and across
    // columns of A^T in each batch. These correspond to rows of A.
    const int64_t block_size =
        num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    worker_threads.workers->ParallelForWithWorkerId(
        batch_size * num_lhs_cols /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64_t batch_and_row_begin, int64_t batch_and_row_end, int tid) {
          HandleBatchAndRowRange(
              num_lhs_cols, batch_and_row_begin, batch_and_row_end,
              [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) {
                const int64_t num_shard_rows = row_end - row_begin;

                // Define a new sparse sub-matrix from the row range
                // [row_begin, row_end) of the sparse matrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding `num_shard_rows` columns of B^T.
                // This is the same as taking the `num_shard_rows` rows of B.
                ConstMatrixMap b_dense_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);

                // Map to the corresponding rows of the output.
                MatrixMap output_map(
                    matmul_result_buffer.flat<T>().data() +
                        tid * batch_size * num_lhs_rows * num_rhs_cols +
                        batch_idx * num_lhs_rows * num_rhs_cols,
                    num_lhs_rows, num_rhs_cols);

                // Compute the product C^T = B^T * A; restricted to the row
                // range in the current shard.
                if (this->conjugate_a_) {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix.conjugate();
                } else {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix;
                }
              });
        });

    // Sum across each thread's matmul result.
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
        Eigen::array<Index, 1>({0}), Reducer());
  }

  // Given a range [batch_and_row_begin, batch_and_row_end) which is a
  // contiguous subset of [0, num_rows * batch_size), calls the function
  // fn(batch_idx, row_begin, row_end) for each batch index
  // and the row range [row_begin, row_end) contained in the batch.
  void HandleBatchAndRowRange(
      const int64_t num_rows, const int64_t batch_and_row_begin,
      const int64_t batch_and_row_end,
      const std::function<void(int64_t, int64_t, int64_t)>& fn) {
    // Obtain the batch indices overlapping with the current shard.
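    // For example, with num_rows == 10 the range [8, 23) is split into
    // batch 0 rows [8, 10), batch 1 rows [0, 10), and batch 2 rows [0, 3).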
    const int64_t batch_begin = batch_and_row_begin / num_rows;
    const int64_t batch_end_inclusive = batch_and_row_end / num_rows;

    for (int64_t batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
         ++batch_idx) {
      // Find the contiguous set of rows which are contained in this shard as
      // well as the current batch. We intersect with the interval
      // [batch_idx * num_rows, (batch_idx + 1) * num_rows), which denotes the
      // current batch.
      const int64_t current_batch_row_begin =
          std::max(batch_and_row_begin, batch_idx * num_rows);
      const int64_t current_batch_row_end =
          std::min(batch_and_row_end, (batch_idx + 1) * num_rows);

      const int64_t row_begin = current_batch_row_begin % num_rows;
      const int64_t num_shard_rows =
          current_batch_row_end - current_batch_row_begin;
      // Edge case for when current_batch_row_end is the first index of a new
      // batch.
      if (num_shard_rows == 0) continue;

      fn(batch_idx, row_begin, row_begin + num_shard_rows);
    }
  }

  // Transposes (and optionally, conjugates) a given Tensor. Also allocates the
  // required memory for the output Tensor.
  Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
                                     bool conjugate, Tensor* output) {
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(input.dims() - 1,
                             input.dim_size(input.dims() - 2));
    transposed_shape.set_dim(input.dims() - 2,
                             input.dim_size(input.dims() - 1));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          transposed_shape, output));
    return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
  }

  // Transposes (and optionally, conjugates) a given Tensor. The output should
  // be already allocated.
  Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
                                              const Tensor& input,
                                              bool conjugate, Tensor* output) {
    if (conjugate) {
      TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
          ctx->eigen_device<CPUDevice>(), input, output));
    } else {
      TF_RETURN_IF_ERROR(
          DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
    }
    return Status::OK();
  }
};

// GPU Kernel to compute sparse-dense matrix multiplication.
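//
// Dispatches to the sparse matrix-vector kernel when the dense operand has a
// single column, and to the SpMM/Csrmm path otherwise. In the matmul path,
// requested transposes of `a` and `b` are materialized explicitly before the
// library call.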
template <typename T>
class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulGPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<GPUDevice, T>(c) {}

  ~CSRMatMulGPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    int rank;
    int64_t batch_size;
    OP_REQUIRES_OK(ctx,
                   this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (rank == 2) ? 0 : 1;
    const int64_t a_outer_dim = a_dense_tensor_shape.dim_size(
        this->transpose_a_ ? row_dim + 1 : row_dim);
    const int64_t b_inner_dim =
        b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
    const int64_t b_outer_dim =
        b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
    const int64_t b_slice_size = b_inner_dim * b_outer_dim;

    TensorShape c_shape;
    if (rank == 3) c_shape.AddDim(batch_size);
    if (this->transpose_output_) {
      c_shape.AddDim(b_outer_dim);
      c_shape.AddDim(a_outer_dim);
    } else {
      c_shape.AddDim(a_outer_dim);
      c_shape.AddDim(b_outer_dim);
    }

    const int64_t c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64_t c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64_t c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    bool use_matrix_vector_multiply = (b_outer_dim == 1);
#if TENSORFLOW_USE_ROCM
    // ROCm hipsparse does not implement csrmv with transposed input a.
    use_matrix_vector_multiply =
        use_matrix_vector_multiply && !this->transpose_a_;
#endif
    if (use_matrix_vector_multiply) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64>::ConstVec a_dense_shape_comp(
          a_dense_shape.data() + row_dim, 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = this->conjugate_a_;
      bool conjugate_output = this->conjugate_output_;
      if (this->conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
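          // We therefore drop the conjugation of `a` here and flip
          // `conjugate_output`, so a single conjugation is applied to the
          // final result instead.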
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
                                                            conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
        this->transpose_output_);

    Tensor c_mat_col_major_t;
    if (!this->transpose_output_) {
      // If transpose_output is false, we'll need to transpose the (col
      // major) output of the csrgemm call to get proper (row-major)
      // output. Which means we need to keep a temporary buffer to
      // store the intermediate gemm output.
      TensorShape c_mat_col_major_shape;
      if (rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major, i.e.,
    // transposed) output of the csrgemm call. Otherwise we'll need
    // to transpose it to row-major format.
    auto c_mat_col_major = (this->transpose_output_)
                               ? c_t->flat<T>()
                               : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!this->transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
      OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
                                    &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!this->transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value,
                                        b_t_transposed_shape, &b_t_input));
      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
      if (this->conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
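    // For rank-3 inputs the dense shape is [batch_size, rows, cols]; the
    // two-element view below drops the leading batch dimension.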
    TTypes<int64>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!this->transpose_output_) {
      // We need to return values in row major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (this->conjugate_output_) {
      functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
    }
  }
};

#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CSRMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CSRMatMulGPUOp<T>);

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

namespace {

// GPUDataType<T>::type translates from a C++ type (e.g. float) to a
// GPUDataType_t (e.g. CUDA_R_32F).
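//
// Specializations are provided below for the value types registered for the
// GPU kernel (float, double, complex64, complex128) as well as Eigen::half.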
template <typename T>
struct GPUDataType;

// GPUDataType templates are currently not instantiated in the ROCm flow,
// so we leave out the #elif TENSORFLOW_USE_ROCM blocks for now.
// The hipblas library is not (yet) being pulled in via rocm_configure.bzl,
// so we cannot reference types from the hipblas headers here.
template <>
struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_16F;
#endif
};

template <>
struct GPUDataType<float> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_32F;
#endif
};

template <>
struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_32F;
#endif
};

template <>
struct GPUDataType<double> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_64F;
#endif
};

template <>
struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_64F;
#endif
};

}  // namespace

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm/SpMM to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm/SpMM assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0.
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n).
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row major output (into which we will store
      // csrgemm's col major output). Otherwise it's a
      // temporary tensor that will store the column major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B) = B, it must be at least
      // max(1, k) if op(A) = A and at least max(1, m) otherwise. If
      // op(B) != B, it must be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
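      // Put differently: the row-major (k x n) matrix B is handed to the
      // library as a transposed column-major (n x k) matrix, which is why
      // ldb == n and ldc == m above.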
#if GOOGLE_CUDA
      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose). This version is actually more
      // efficient.
#if GOOGLE_CUDA && CUDA_VERSION >= 10020

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
                              GPUDataType<T>::type, CUSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
                              CUSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          CUSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));

#elif TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 40200
      // Use SpMM.
      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
          &matB, n, k, ldb, const_cast<T*>(b.data()), GPUDataType<T>::type,
          HIPSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateDnMat(
          &matC, m, n, ldc, c.data(), GPUDataType<T>::type,
          HIPSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          HIPSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseDestroySpMat(matA));

#else

#if GOOGLE_CUDA

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
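      // Legacy (pre-CUDA 10.2) path: describe A as a general CSR matrix with
      // zero-based indexing and call csrmm directly.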
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

#elif TENSORFLOW_USE_ROCM

      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA

      TF_RETURN_IF_ERROR(cuda_sparse.Csrmm(
          transA, transB, m, n, k, nnz, &alpha, descrA, a.values.data(),
          a.row_ptr.data(), a.col_ind.data(), b.data(), ldb, &beta, c.data(),
          ldc));

#endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
    }

    return Status::OK();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
                                                   &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0.
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

#if GOOGLE_CUDA && CUDA_VERSION < 10020
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA && CUDA_VERSION < 10020

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#else
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#endif
    }

    return Status::OK();
  }

 private:
  Status status_;
  const gpusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow