/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

// TODO(anudhyan): These constants may be tuned based on the performance of
// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
// which work across hardware platforms for typical matrix sizes. It should be
// possible to observe at least 30-50% improvement as we increase the number
// of threads by 1. If not, then it may be worth increasing kMaxShards and
// kNumShardsPerThread. However, once we have too many shards, latency may be
// dominated by per-shard overhead.
//
// Maximum number of shards into which to divide the computation for each CSR
// Sparse Matrix instance.
static constexpr int32 kMaxShards = 20;
// Number of shards allocated to each thread.
static constexpr int32 kNumShardsPerThread = 3;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Abstract OpKernel to compute sparse-dense matrix multiplication.
//
// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
// multiplication.
//
// The boolean attributes `transpose_a` and `adjoint_a` will transpose or
// adjoint `a` before multiplication, respectively. At most one of these
// attributes may be set to True. Corresponding attributes will transpose or
// adjoint `b` or the output (after multiplication).
//
// The rank of both `a` and `b` must be equal and their shapes must be
// compatible for matrix multiplication. Otherwise, InvalidArgument runtime
// errors will be thrown. Only rank 2 or rank 3 inputs are supported.
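//
// Illustrative shapes (not an exhaustive contract): for rank-3 inputs, a
// CSRSparseMatrix `a` with dense shape [B, M, K] and a dense `b` of shape
// [B, K, N] produce `c` of shape [B, M, N], or [B, N, M] when
// `transpose_output` is true.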
//
template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ |= adjoint_a;
    transpose_b_ |= adjoint_b;
  }

  ~CSRMatMulOp() override {}

  Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
                        const Tensor& dense_tensor_b, int* rank,
                        int64* batch_size) {
    if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
      return errors::InvalidArgument(
          "Input types don't match. a.dtype == ",
          DataTypeString(sparse_matrix_a.dtype()),
          " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
    }
    *rank = sparse_matrix_a.dims();
    // TODO(ebrevdo): Add support for broadcasting matmul.
    if (*rank != dense_tensor_b.dims()) {
      return errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                     *rank, " vs. ", dense_tensor_b.dims(),
                                     ".");
    }
    // A valid CSR SparseMatrix has rank 2 or rank 3.
    *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
    if (sparse_matrix_a.batch_size() != *batch_size) {
      return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
                                     sparse_matrix_a.batch_size(), " vs. ",
                                     *batch_size, ".");
    }
    const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64>();
    const int64 a_inner_dim =
        a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
    const int64 b_inner_dim =
        dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
    if (a_inner_dim != b_inner_dim) {
      return errors::InvalidArgument(
          "Inner product dimensions of A and B do not agree. Shapes are: ",
          TensorShape(a_dense_shape), " vs. ",
          dense_tensor_b.shape().DebugString());
    }
    return Status::OK();
  }

 public:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};

// CPU Kernel to compute sparse-dense matrix multiplication.
//
// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
// available, the implementation parallelizes the computation across each row
// of the sparse matrix.
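//
// Roughly, each shard maps a contiguous block of rows of the CSR matrix as an
// Eigen::Map<const Eigen::SparseMatrix<T, Eigen::RowMajor>> and multiplies it
// against a dense Eigen::Map of `b`; only the row pointers are rebased per
// shard, while the column indices and values are used in place.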
template <typename T>
class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulCPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<CPUDevice, T>(c) {}

  ~CSRMatMulCPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* sparse_matrix_a;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
    const Tensor& matrix_b = ctx->input(1);

    int rank;
    int64 batch_size;
    OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
                                             &batch_size));

    const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64>();
    int64 num_lhs_rows = dense_shape(rank - 2);
    int64 num_lhs_cols = dense_shape(rank - 1);
    int64 num_rhs_rows = matrix_b.dim_size(rank - 2);
    int64 num_rhs_cols = matrix_b.dim_size(rank - 1);

    if (this->transpose_a_) {
      std::swap(num_lhs_rows, num_lhs_cols);
    }

    // Possibly transpose the dense Tensor b.
    const Tensor* rhs = &matrix_b;
    Tensor b_transposed;
    if (this->transpose_b_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
                                           &b_transposed));
      rhs = &b_transposed;
      std::swap(num_rhs_rows, num_rhs_cols);
    }

    // If we're transposing the output, then allocate a temporary buffer to
    // store the output. Otherwise allocate the output directly.
    Tensor* output = nullptr;
    Tensor* matmul_result = nullptr;
    Tensor output_transposed;
    OP_REQUIRES_OK(
        ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
                            this->transpose_output_, &output,
                            &output_transposed, &matmul_result));

    if (!this->transpose_a_) {
      SparseDenseMatMulWithoutTransposedLHS(
          ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
    } else {  // transpose_a_ == true
      SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
                                         num_lhs_cols, *sparse_matrix_a, *rhs,
                                         matmul_result);
    }

    // Transpose (and conjugate) the output if necessary.
    // Note that conjugate is only true if transpose is also true.
    if (this->transpose_output_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateAllocatedTensor(
                   ctx, output_transposed, this->conjugate_output_, output));
    } else if (this->conjugate_output_) {
      functor::maybe_conj_inplace<CPUDevice, T>::run(
          ctx->eigen_device<CPUDevice>(), output);
    }
  }

 private:
  // Allocates the output with the appropriate shape. Additionally, if
  // transpose_output is True, allocates a temporary buffer with the transposed
  // output. 'matmul_result' points to either output or output_transposed,
  // based on whether transpose_output is True.
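  // For example, with transpose_output == true the final output gets trailing
  // dims {num_cols, num_rows}, while the temporary 'output_transposed' buffer
  // (which is what the matmul actually writes into) keeps
  // {num_rows, num_cols}.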
  Status AllocateOutput(OpKernelContext* ctx, const int32 rank,
                        const int64 batch_size, const int64 num_rows,
                        const int64 num_cols, const bool transpose_output,
                        Tensor** output, Tensor* output_transposed,
                        Tensor** matmul_result) {
    TensorShape output_shape;
    if (rank == 3) output_shape.AddDim(batch_size);

    if (!transpose_output) {
      output_shape.AppendShape({num_rows, num_cols});
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = *output;
    } else {
      TensorShape output_transposed_shape = output_shape;
      output_transposed_shape.AppendShape({num_rows, num_cols});
      output_shape.AppendShape({num_cols, num_rows});
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                            output_transposed_shape,
                                            output_transposed));
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = output_transposed;
    }
    return Status::OK();
  }

  // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
  // contiguous segment of rows of the CSR Sparse Matrix.
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int batch_index,
      const int64 row_begin, const int64 num_shard_rows,
      std::vector<int32>* row_ptrs) {
    // Compute the row pointers of the sparse sub-matrix.
    row_ptrs->resize(num_shard_rows + 1);
    const int64 row_offset =
        csr_matrix.row_pointers_vec(batch_index)(row_begin);
    for (int64 row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
      row_ptrs->at(row_idx) =
          csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
          row_offset;
    }
    const int64 num_cols =
        csr_matrix.dense_shape().vec<int64>()(csr_matrix.dims() - 1);
    return Eigen::Map<const SparseMatrix>(
        num_shard_rows /* num_rows */, num_cols /* num_cols */,
        row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
        csr_matrix.col_indices_vec(batch_index).data() + row_offset,
        csr_matrix.values_vec<T>(batch_index).data() + row_offset);
  }

  // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
  // dense Tensor (RHS).
  void SparseDenseMatMulWithoutTransposedLHS(
      OpKernelContext* ctx, const int64 batch_size, const int64 num_lhs_rows,
      const CSRSparseMatrix& lhs, const Tensor& rhs, Tensor* output) {
    // Parallelize matrix multiplication across batch dimensions and across
    // rows in each batch.
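    // The iteration space is flattened to [0, batch_size * num_lhs_rows), so
    // each shard covers a contiguous range of (batch, row) pairs which
    // HandleBatchAndRowRange splits back into per-batch row ranges.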
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32 num_threads = worker_threads.num_threads;
    const int64 block_size =
        num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    worker_threads.workers->ParallelFor(
        batch_size * num_lhs_rows /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64 batch_and_row_begin, int64 batch_and_row_end) {
          HandleBatchAndRowRange(
              num_lhs_rows, batch_and_row_begin, batch_and_row_end,
              [&](int64 batch_idx, int64 row_begin, int64 row_end) {
                const int64 num_shard_rows = row_end - row_begin;

                // Define an Eigen::SparseMatrix over the row range:
                // [row_begin, row_end) of the CSR SparseMatrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding rows of the rhs.
                ConstMatrixMap rhs_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols,
                    num_rhs_rows, num_rhs_cols);

                // Write to the corresponding rows of the output matrix.
                MatrixMap output_map(
                    output->flat<T>().data() +
                        batch_idx * num_lhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);
                output_map.noalias() = sparse_matrix * rhs_map;
              });
        });
  }

  // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
  // to be transposed before the operation.
  void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
                                          const int64 batch_size,
                                          const int64 num_lhs_rows,
                                          const int64 num_lhs_cols,
                                          const CSRSparseMatrix& lhs,
                                          const Tensor& rhs, Tensor* output) {
    auto device = ctx->eigen_device<CPUDevice>();
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32 num_threads = worker_threads.num_threads;
    const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    // Usually, we want to avoid transposing the sparse matrix A since it may
    // be an expensive operation. Instead, we use the identity
    // (A^T B) = (B^T A)^T. We don't actually transpose B or the output
    // because it is more convenient to have them in column major form.
    //
    // However, if A is hypersparse and B and C are huge, transposing A will be
    // cheaper. In the future, we should have a cost model estimating the cost
    // of transposing all matrices (A, B, C) to decide which variant to use.

    // Each thread writes to its own copy of the matrix product. These
    // `num_threads` copies are summed together to obtain the final result.
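    // (The buffer below actually holds num_threads + 1 copies, since
    // ParallelForWithWorkerId may also run shards on the calling thread, so
    // worker ids can take one more value than the number of pool threads.)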
    Tensor matmul_result_buffer;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({num_threads + 1,
                                                        output->NumElements()}),
                                           &matmul_result_buffer));
    functor::SetZeroFunctor<CPUDevice, T> set_zero;
    set_zero(device, matmul_result_buffer.flat<T>());

    // Parallelize matrix multiplication across batch dimensions and across
    // columns of A^T in each batch. These correspond to rows of A.
    const int64 block_size =
        num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    worker_threads.workers->ParallelForWithWorkerId(
        batch_size * num_lhs_cols /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64 batch_and_row_begin, int64 batch_and_row_end, int tid) {
          HandleBatchAndRowRange(
              num_lhs_cols, batch_and_row_begin, batch_and_row_end,
              [&](int64 batch_idx, int64 row_begin, int64 row_end) {
                const int64 num_shard_rows = row_end - row_begin;

                // Define a new sparse sub-matrix from the row range
                // [row_begin, row_end) of the sparse matrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding `num_shard_rows` columns of B^T.
                // This is the same as taking the `num_shard_rows` rows of B.
                ConstMatrixMap b_dense_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);

                // Map to the corresponding rows of the output.
                MatrixMap output_map(
                    matmul_result_buffer.flat<T>().data() +
                        tid * batch_size * num_lhs_rows * num_rhs_cols +
                        batch_idx * num_lhs_rows * num_rhs_cols,
                    num_lhs_rows, num_rhs_cols);

                // Compute the product C^T = B^T * A; restricted to the row
                // range in the current shard.
                if (this->conjugate_a_) {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix.conjugate();
                } else {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix;
                }
              });
        });

    // Sum across each thread's matmul result.
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
        Eigen::array<Index, 1>({0}), Reducer());
  }

  // Given a range [batch_and_row_begin, batch_and_row_end) which is a
  // contiguous subset of [0, num_rows * batch_size), calls the function
  // fn(batch_idx, row_begin, row_end) for each batch index
  // and the row range [row_begin, row_end) contained in the batch.
  void HandleBatchAndRowRange(
      const int64 num_rows, const int64 batch_and_row_begin,
      const int64 batch_and_row_end,
      const std::function<void(int64, int64, int64)>& fn) {
    // Obtain the batch indices overlapping with the current shard.
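    // For example (illustrative): with num_rows = 10, the range [8, 23) spans
    // rows [8, 10) of batch 0, all rows of batch 1, and rows [0, 3) of
    // batch 2.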
    const int64 batch_begin = batch_and_row_begin / num_rows;
    const int64 batch_end_inclusive = batch_and_row_end / num_rows;

    for (int64 batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
         ++batch_idx) {
      // Find the contiguous set of rows which are contained in this shard as
      // well as the current batch. We intersect with the interval
      // [batch_idx * num_rows, (batch_idx + 1) * num_rows), which denotes the
      // current batch.
      const int64 current_batch_row_begin =
          std::max(batch_and_row_begin, batch_idx * num_rows);
      const int64 current_batch_row_end =
          std::min(batch_and_row_end, (batch_idx + 1) * num_rows);

      const int64 row_begin = current_batch_row_begin % num_rows;
      const int64 num_shard_rows =
          current_batch_row_end - current_batch_row_begin;
      // Edge case for when current_batch_row_end is the first index of a new
      // batch.
      if (num_shard_rows == 0) continue;

      fn(batch_idx, row_begin, row_begin + num_shard_rows);
    }
  }

  // Transposes (and optionally, conjugates) a given Tensor. Also allocates the
  // required memory for the output Tensor.
  Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
                                     bool conjugate, Tensor* output) {
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(input.dims() - 1,
                             input.dim_size(input.dims() - 2));
    transposed_shape.set_dim(input.dims() - 2,
                             input.dim_size(input.dims() - 1));
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
    return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
  }

  // Transposes (and optionally, conjugates) a given Tensor. The output should
  // be already allocated.
  Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
                                              const Tensor& input,
                                              bool conjugate, Tensor* output) {
    if (conjugate) {
      TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
          ctx->eigen_device<CPUDevice>(), input, output));
    } else {
      TF_RETURN_IF_ERROR(
          DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
    }
    return Status::OK();
  }
};

// GPU Kernel to compute sparse-dense matrix multiplication.
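//
// Dispatches to the GpuSparse (cuSPARSE / hipSPARSE) wrappers defined below:
// Csrmv when the dense operand has a single column, and Csrmm otherwise.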
template <typename T>
class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulGPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<GPUDevice, T>(c) {}

  ~CSRMatMulGPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    int rank;
    int64 batch_size;
    OP_REQUIRES_OK(ctx,
                   this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (rank == 2) ? 0 : 1;
    const int64 a_outer_dim = a_dense_tensor_shape.dim_size(
        this->transpose_a_ ? row_dim + 1 : row_dim);
    const int64 b_inner_dim =
        b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
    const int64 b_outer_dim =
        b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
    const int64 b_slice_size = b_inner_dim * b_outer_dim;

    TensorShape c_shape;
    if (rank == 3) c_shape.AddDim(batch_size);
    if (this->transpose_output_) {
      c_shape.AddDim(b_outer_dim);
      c_shape.AddDim(a_outer_dim);
    } else {
      c_shape.AddDim(a_outer_dim);
      c_shape.AddDim(b_outer_dim);
    }

    const int64 c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64 c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64 c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();

    if (b_outer_dim == 1) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64>::ConstVec a_dense_shape_comp(
          a_dense_shape.data() + row_dim, 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = this->conjugate_a_;
      bool conjugate_output = this->conjugate_output_;
      if (this->conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
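          // Dropping both conjugations here and flipping conjugate_output
          // applies the single remaining conjugation to the final result.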
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
                                                            conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
        this->transpose_output_);

    Tensor c_mat_col_major_t;
    if (!this->transpose_output_) {
      // If transpose_output is false, we'll need to transpose the (col-major)
      // output of the csrgemm call to get proper (row-major) output, which
      // means we need to keep a temporary buffer to store the intermediate
      // gemm output.
      TensorShape c_mat_col_major_shape;
      if (rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major i.e.,
    // transposed) output of the csrgemm call. Otherwise we'll need
    // to transpose it to row major format.
    auto c_mat_col_major = (this->transpose_output_)
                               ? c_t->flat<T>()
                               : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!this->transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
      OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
                                    &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!this->transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             b_t_transposed_shape, &b_t_input));
      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
      if (this->conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
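    // (i.e. the trailing {num_rows, num_cols} pair of the dense shape of the
    // possibly transposed A.)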
    TTypes<int64>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!this->transpose_output_) {
      // We need to return values in row major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (this->conjugate_output_) {
      functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
    }
  }
};

#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CSRMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CSRMatMulGPUOp<T>);

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
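      // (B is stored row-major here, so the same buffer read in column-major
      // order is B^T; declaring transB as transpose lets Csrmm consume it
      // without an explicit copy.)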
#if GOOGLE_CUDA
      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose). This version is actually more
      // efficient.
#if GOOGLE_CUDA
      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row major output (into which we will store
      // csrgemm's col major output). Otherwise it's a
      // temporary tensor that will store the column major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B) = B, it must be at least
      // max(1, k) if op(A) = A, and at least max(1, m) otherwise. If
      // op(B) != B, it must be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;

      TF_RETURN_IF_ERROR(
          cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                            a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                            b.data(), ldb, &beta, c.data(), ldc));
    }

    return Status::OK();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
                                                   &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

      gpusparseMatDescr_t descrA;
#if GOOGLE_CUDA
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
    }

    return Status::OK();
  }

 private:
  Status status_;
  const gpusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow