/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

// TODO(anudhyan): These constants may be tuned based on the performance of
// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants
// which work across hardware platforms for typical matrix sizes. It should be
// possible to observe at least 30-50% improvement as we increase the number
// of threads by 1. If not, then it may be worth increasing kMaxShards and
// kNumShardsPerThread. However, once we have too many shards, latency may be
// dominated by per-shard overhead.
//
// Maximum number of shards into which to divide the computation for each CSR
// Sparse Matrix instance.
static constexpr int32 kMaxShards = 20;
// Number of shards allocated to each thread.
static constexpr int32 kNumShardsPerThread = 3;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Abstract OpKernel to compute sparse-dense matrix multiplication.
//
// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`,
// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix
// multiplication.
//
// The boolean attributes `transpose_a` and `adjoint_a` will transpose or
// adjoint `a` before multiplication, respectively. At most one of these
// attributes may be set to True. Corresponding attributes will transpose or
// adjoint `b` or the output (after multiplication).
//
// The rank of both `a` and `b` must be equal and their shapes must be
// compatible for matrix multiplication. Otherwise, InvalidArgument runtime
// errors will be thrown. Only rank 2 or rank 3 inputs are supported.
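//
// For example (illustrative shapes only): if `a` has dense shape [3, 5, 7]
// and `b` has shape [3, 7, 4], then `c = a * b` has shape [3, 5, 4]; with
// `transpose_output` set to True it is returned as [3, 4, 5] instead.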
//
template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ |= adjoint_a;
    transpose_b_ |= adjoint_b;
  }

  ~CSRMatMulOp() override {}

  Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a,
                        const Tensor& dense_tensor_b, int* rank,
                        int64* batch_size) {
    if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) {
      return errors::InvalidArgument(
          "Input types don't match. a.dtype == ",
          DataTypeString(sparse_matrix_a.dtype()),
          " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()));
    }
    *rank = sparse_matrix_a.dims();
    // TODO(ebrevdo): Add support for broadcasting matmul.
    if (*rank != dense_tensor_b.dims()) {
      return errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                     *rank, " vs. ", dense_tensor_b.dims(),
                                     ".");
    }
    // A valid CSR SparseMatrix has rank 2 or rank 3.
    *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0);
    if (sparse_matrix_a.batch_size() != *batch_size) {
      return errors::InvalidArgument("Batch sizes of a and b must match, saw: ",
                                     sparse_matrix_a.batch_size(), " vs. ",
                                     *batch_size, ".");
    }
    const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec<int64>();
    const int64 a_inner_dim =
        a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1);
    const int64 b_inner_dim =
        dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2);
    if (a_inner_dim != b_inner_dim) {
      return errors::InvalidArgument(
          "Inner product dimensions of A and B do not agree. Shapes are: ",
          TensorShape(a_dense_shape), " vs. ",
          dense_tensor_b.shape().DebugString());
    }
    return Status::OK();
  }

 public:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};

// CPU Kernel to compute sparse-dense matrix multiplication.
//
// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between
// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is
// available, the implementation parallelizes the computation across each row
// of the sparse matrix.
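//
// For example (illustrative): with 8 intra-op threads the row dimension is
// split into max(kMaxShards, kNumShardsPerThread * 8) = 24 blocks, so each
// shard covers roughly num_lhs_rows / 24 consecutive rows.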
template <typename T>
class CSRMatMulCPUOp : public CSRMatMulOp<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulCPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<CPUDevice, T>(c) {}

  ~CSRMatMulCPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* sparse_matrix_a;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a));
    const Tensor& matrix_b = ctx->input(1);

    int rank;
    int64 batch_size;
    OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank,
                                             &batch_size));

    const auto dense_shape = sparse_matrix_a->dense_shape().vec<int64>();
    int64 num_lhs_rows = dense_shape(rank - 2);
    int64 num_lhs_cols = dense_shape(rank - 1);
    int64 num_rhs_rows = matrix_b.dim_size(rank - 2);
    int64 num_rhs_cols = matrix_b.dim_size(rank - 1);

    if (this->transpose_a_) {
      std::swap(num_lhs_rows, num_lhs_cols);
    }

    // Possibly transpose the dense Tensor b.
    const Tensor* rhs = &matrix_b;
    Tensor b_transposed;
    if (this->transpose_b_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_,
                                           &b_transposed));
      rhs = &b_transposed;
      std::swap(num_rhs_rows, num_rhs_cols);
    }

    // If we're transposing the output, then allocate a temporary buffer to
    // store the output. Otherwise allocate the output directly.
    Tensor* output = nullptr;
    Tensor* matmul_result = nullptr;
    Tensor output_transposed;
    OP_REQUIRES_OK(
        ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols,
                            this->transpose_output_, &output,
                            &output_transposed, &matmul_result));

    if (!this->transpose_a_) {
      SparseDenseMatMulWithoutTransposedLHS(
          ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result);
    } else {  // transpose_a_ == true
      SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows,
                                         num_lhs_cols, *sparse_matrix_a, *rhs,
                                         matmul_result);
    }

    // Transpose (and conjugate) the output if necessary.
    // Note that conjugate is only true if transpose is also true.
    if (this->transpose_output_) {
      OP_REQUIRES_OK(
          ctx, TransposeAndConjugateAllocatedTensor(
                   ctx, output_transposed, this->conjugate_output_, output));
    } else if (this->conjugate_output_) {
      functor::maybe_conj_inplace<CPUDevice, T>::run(
          ctx->eigen_device<CPUDevice>(), output);
    }
  }

 private:
  // Allocates the output with the appropriate shape. Additionally, if
  // transpose_output is True, allocates a temporary buffer with the
  // transposed output. 'matmul_result' points to either output or
  // output_transposed, based on whether transpose_output is True.
  Status AllocateOutput(OpKernelContext* ctx, const int32 rank,
                        const int64 batch_size, const int64 num_rows,
                        const int64 num_cols, const bool transpose_output,
                        Tensor** output, Tensor* output_transposed,
                        Tensor** matmul_result) {
    TensorShape output_shape;
    if (rank == 3) output_shape.AddDim(batch_size);

    if (!transpose_output) {
      output_shape.AppendShape({num_rows, num_cols});
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = *output;
    } else {
      TensorShape output_transposed_shape = output_shape;
      output_transposed_shape.AppendShape({num_rows, num_cols});
      output_shape.AppendShape({num_cols, num_rows});
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                            output_transposed_shape,
                                            output_transposed));
      TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output));
      *matmul_result = output_transposed;
    }
    return Status::OK();
  }

  // Returns an Eigen::Ref expression of a sparse sub-matrix from the given
  // contiguous segment of rows of the CSR Sparse Matrix.
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int batch_index,
      const int64 row_begin, const int64 num_shard_rows,
      std::vector<int32>* row_ptrs) {
    // Compute the row pointers of the sparse sub-matrix.
    row_ptrs->resize(num_shard_rows + 1);
    const int64 row_offset =
        csr_matrix.row_pointers_vec(batch_index)(row_begin);
    for (int64 row_idx = 0; row_idx <= num_shard_rows; ++row_idx) {
      row_ptrs->at(row_idx) =
          csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) -
          row_offset;
    }
    const int64 num_cols =
        csr_matrix.dense_shape().vec<int64>()(csr_matrix.dims() - 1);
    return Eigen::Map<const SparseMatrix>(
        num_shard_rows /* num_rows */, num_cols /* num_cols */,
        row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(),
        csr_matrix.col_indices_vec(batch_index).data() + row_offset,
        csr_matrix.values_vec<T>(batch_index).data() + row_offset);
  }

  // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a
  // dense Tensor (RHS).
  void SparseDenseMatMulWithoutTransposedLHS(
      OpKernelContext* ctx, const int64 batch_size, const int64 num_lhs_rows,
      const CSRSparseMatrix& lhs, const Tensor& rhs, Tensor* output) {
    // Parallelize matrix multiplication across batch dimensions and across
    // rows in each batch.
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32 num_threads = worker_threads.num_threads;
    const int64 block_size =
        num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    worker_threads.workers->ParallelFor(
        batch_size * num_lhs_rows /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64 batch_and_row_begin, int64 batch_and_row_end) {
          HandleBatchAndRowRange(
              num_lhs_rows, batch_and_row_begin, batch_and_row_end,
              [&](int64 batch_idx, int64 row_begin, int64 row_end) {
                const int64 num_shard_rows = row_end - row_begin;

                // Define an Eigen::SparseMatrix over the row range:
                // [row_begin, row_end) of the CSR SparseMatrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding rows of the rhs.
                ConstMatrixMap rhs_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols,
                    num_rhs_rows, num_rhs_cols);

                // Write to the corresponding rows of the output matrix.
                MatrixMap output_map(
                    output->flat<T>().data() +
                        batch_idx * num_lhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);
                output_map.noalias() = sparse_matrix * rhs_map;
              });
        });
  }

  // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is
  // to be transposed before the operation.
  void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx,
                                          const int64 batch_size,
                                          const int64 num_lhs_rows,
                                          const int64 num_lhs_cols,
                                          const CSRSparseMatrix& lhs,
                                          const Tensor& rhs, Tensor* output) {
    auto device = ctx->eigen_device<CPUDevice>();
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int32 num_threads = worker_threads.num_threads;
    const int64 num_rhs_rows = rhs.dim_size(rhs.dims() - 2);
    const int64 num_rhs_cols = rhs.dim_size(rhs.dims() - 1);
    // Usually, we want to avoid transposing the sparse matrix A since it may
    // be an expensive operation. Instead, we use the identity
    // (A^T B) = (B^T A)^T. We don't actually transpose B or the output
    // because it is more convenient to have them in column major form.
    //
    // However, if A is hypersparse and B and C are huge, transposing A will
    // be cheaper. In the future, we should have a cost model estimating the
    // cost of transposing all matrices (A, B, C) to decide which variant to
    // use.

    // Each thread writes to its own copy of the matrix product. These
    // `num_threads` copies are summed together to obtain the final result.
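    //
    // Concretely, each shard below computes
    //   C^T += B^T[:, row_begin:row_end) * A[row_begin:row_end, :],
    // so summing the per-thread partial products over all shards yields
    // C^T = B^T * A, i.e. C = A^T * B, without ever materializing A^T.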
    Tensor matmul_result_buffer;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({num_threads + 1, output->NumElements()}),
                       &matmul_result_buffer));
    functor::SetZeroFunctor<CPUDevice, T> set_zero;
    set_zero(device, matmul_result_buffer.flat<T>());

    // Parallelize matrix multiplication across batch dimensions and across
    // columns of A^T in each batch. These correspond to rows of A.
    const int64 block_size =
        num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads);
    worker_threads.workers->ParallelForWithWorkerId(
        batch_size * num_lhs_cols /* total */,
        thread::ThreadPool::SchedulingParams(
            thread::ThreadPool::SchedulingStrategy::
                kFixedBlockSize /* strategy */,
            absl::nullopt /* cost_per_unit */, block_size),
        [&](int64 batch_and_row_begin, int64 batch_and_row_end, int tid) {
          HandleBatchAndRowRange(
              num_lhs_cols, batch_and_row_begin, batch_and_row_end,
              [&](int64 batch_idx, int64 row_begin, int64 row_end) {
                const int64 num_shard_rows = row_end - row_begin;

                // Define a new sparse sub-matrix from the row range
                // [row_begin, row_end) of the sparse matrix A.
                std::vector<int32> row_ptrs;
                auto sparse_matrix = GetSparseMatrixRef(
                    lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs);

                // Map the corresponding `num_shard_rows` columns of B^T.
                // This is the same as taking the `num_shard_rows` rows of B.
                ConstMatrixMap b_dense_map(
                    rhs.flat<T>().data() +
                        batch_idx * num_rhs_rows * num_rhs_cols +
                        row_begin * num_rhs_cols,
                    num_shard_rows, num_rhs_cols);

                // Map to the corresponding rows of the output.
                MatrixMap output_map(
                    matmul_result_buffer.flat<T>().data() +
                        tid * batch_size * num_lhs_rows * num_rhs_cols +
                        batch_idx * num_lhs_rows * num_rhs_cols,
                    num_lhs_rows, num_rhs_cols);

                // Compute the product C^T = B^T * A; restricted to the row
                // range in the current shard.
                if (this->conjugate_a_) {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix.conjugate();
                } else {
                  output_map.transpose().noalias() +=
                      b_dense_map.transpose() * sparse_matrix;
                }
              });
        });

    // Sum across each thread's matmul result.
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    output->flat<T>().device(device) = matmul_result_buffer.matrix<T>().reduce(
        Eigen::array<Index, 1>({0}), Reducer());
  }

  // Given a range [batch_and_row_begin, batch_and_row_end) which is a
  // contiguous subset of [0, num_rows * batch_size), calls the function
  // fn(batch_idx, row_begin, row_end) for each batch index
  // and the row range [row_begin, row_end) contained in the batch.
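  //
  // For example (illustrative): with num_rows = 10 and the range [8, 23),
  // fn is called as fn(0, 8, 10), fn(1, 0, 10) and fn(2, 0, 3).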
  void HandleBatchAndRowRange(
      const int64 num_rows, const int64 batch_and_row_begin,
      const int64 batch_and_row_end,
      const std::function<void(int64, int64, int64)>& fn) {
    // Obtain the batch indices overlapping with the current shard.
    const int64 batch_begin = batch_and_row_begin / num_rows;
    const int64 batch_end_inclusive = batch_and_row_end / num_rows;

    for (int64 batch_idx = batch_begin; batch_idx <= batch_end_inclusive;
         ++batch_idx) {
      // Find the contiguous set of rows which are contained in this shard as
      // well as the current batch. We intersect with interval [batch_idx *
      // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch.
      const int64 current_batch_row_begin =
          std::max(batch_and_row_begin, batch_idx * num_rows);
      const int64 current_batch_row_end =
          std::min(batch_and_row_end, (batch_idx + 1) * num_rows);

      const int64 row_begin = current_batch_row_begin % num_rows;
      const int64 num_shard_rows =
          current_batch_row_end - current_batch_row_begin;
      // Edge case for when current_batch_row_end is the first index of a new
      // batch.
      if (num_shard_rows == 0) continue;

      fn(batch_idx, row_begin, row_begin + num_shard_rows);
    }
  }

  // Transposes (and optionally, conjugates) a given Tensor. Also allocates
  // the required memory for the output Tensor.
  Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input,
                                     bool conjugate, Tensor* output) {
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(input.dims() - 1,
                             input.dim_size(input.dims() - 2));
    transposed_shape.set_dim(input.dims() - 2,
                             input.dim_size(input.dims() - 1));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          transposed_shape, output));
    return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output);
  }

  // Transposes (and optionally, conjugates) a given Tensor. The output should
  // be already allocated.
  Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx,
                                              const Tensor& input,
                                              bool conjugate, Tensor* output) {
    if (conjugate) {
      TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose(
          ctx->eigen_device<CPUDevice>(), input, output));
    } else {
      TF_RETURN_IF_ERROR(
          DoMatrixTranspose(ctx->eigen_device<CPUDevice>(), input, output));
    }
    return Status::OK();
  }
};

// GPU Kernel to compute sparse-dense matrix multiplication.
template <typename T>
class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

 public:
  explicit CSRMatMulGPUOp(OpKernelConstruction* c)
      : CSRMatMulOp<GPUDevice, T>(c) {}

  ~CSRMatMulGPUOp() override {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    int rank;
    int64 batch_size;
    OP_REQUIRES_OK(ctx,
                   this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (rank == 2) ? 0 : 1;
    const int64 a_outer_dim = a_dense_tensor_shape.dim_size(
        this->transpose_a_ ? row_dim + 1 : row_dim);
    const int64 b_inner_dim =
        b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim);
    const int64 b_outer_dim =
        b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1);
    const int64 b_slice_size = b_inner_dim * b_outer_dim;

    TensorShape c_shape;
    if (rank == 3) c_shape.AddDim(batch_size);
    if (this->transpose_output_) {
      c_shape.AddDim(b_outer_dim);
      c_shape.AddDim(a_outer_dim);
    } else {
      c_shape.AddDim(a_outer_dim);
      c_shape.AddDim(b_outer_dim);
    }

    const int64 c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64 c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64 c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    bool use_matrix_vector_multiply = (b_outer_dim == 1);
#if TENSORFLOW_USE_ROCM
    // ROCm hipsparse does not implement csrmv with transposed input a.
    use_matrix_vector_multiply =
        use_matrix_vector_multiply && !this->transpose_a_;
#endif
    if (use_matrix_vector_multiply) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
                                                 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = this->conjugate_a_;
      bool conjugate_output = this->conjugate_output_;
      if (this->conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<GPUDevice, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<GPUDevice, T> csr_spmv(this->transpose_a_,
                                                            conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<GPUDevice, T> csr_spmmadd(
        this->transpose_output_);

    Tensor c_mat_col_major_t;
    if (!this->transpose_output_) {
      // If transpose_output is false, we'll need to transpose the (col
      // major) output of the csrgemm call to get proper (row-major)
      // output. Which means we need to keep a temporary buffer to
      // store the intermediate gemm output.
      TensorShape c_mat_col_major_shape;
      if (rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major i.e.,
    // transposed) output of the csrgemm call. Otherwise we'll need
    // to transpose it to row major format.
    auto c_mat_col_major = (this->transpose_output_)
                               ? c_t->flat<T>()
                               : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!this->transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<GPUDevice, T> transpose;
      OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix,
                                    &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!this->transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             b_t_transposed_shape, &b_t_input));
      const GPUDevice& d = ctx->eigen_device<GPUDevice>();
      if (this->conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
    TTypes<int64>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!this->transpose_output_) {
      // We need to return values in row major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (this->conjugate_output_) {
      functor::maybe_conj_inplace<GPUDevice, T>::run(d, c_t);
    }
  }
};

#define REGISTER_CPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CSRMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                                     \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseMatrixMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CSRMatMulGPUOp<T>);

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {

namespace {

// GPUDataType<T>::type translates from a C++ type (e.g. float) to a
// GPUDataType_t (e.g. CUDA_R_32F).
template <typename T>
struct GPUDataType;

// GPUDataType templates are currently not instantiated in the ROCm flow, so
// the #elif TENSORFLOW_USE_ROCM blocks are left out for now. The hipblas
// library is not (yet) being pulled in via rocm_configure.bzl, so we cannot
// reference types from hipblas headers here.
template <>
struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_16F;
#endif
};

template <>
struct GPUDataType<float> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_32F;
#endif
};

template <>
struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_32F;
#endif
};

template <>
struct GPUDataType<double> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_R_64F;
#endif
};

template <>
struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA
  static constexpr cudaDataType_t type = CUDA_C_64F;
#endif
};

}  // namespace

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm/SpMM to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm/SpMM assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0.
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n).
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row major output (into which we will store
      // csrgemm's col major output). Otherwise it's a
      // temporary tensor that will store the column major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k)
      // if op(A) = A and at least max(1, m) otherwise. If op(B) != B, it must
      // be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
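      //
      // Note (illustrative reasoning): the row-major (k, n) matrix `b` has
      // the same memory layout as a column-major (n, k) matrix, so below it
      // is described to the sparse library as an (n, k) column-major matrix
      // with leading dimension ldb == n and transB == transpose, which makes
      // op(B) the original (k, n) matrix B. The output is likewise written in
      // column-major form and transposed (if needed) by the caller.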
#if GOOGLE_CUDA
      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose). This version is actually more
      // efficient.
#if GOOGLE_CUDA && CUDA_VERSION >= 10020

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
      gpusparseSpMatDescr_t matA;
      gpusparseDnMatDescr_t matB, matC;

      // NOTE: The following APIs are not available in ROCm.
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
          GPUDataType<T>::type));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
                              GPUDataType<T>::type, CUSPARSE_ORDER_COL));

      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseCreateDnMat(&matC, m, n, ldc, c.data(), GPUDataType<T>::type,
                              CUSPARSE_ORDER_COL));

      size_t bufferSize = 0;
      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
          transA, transB, &alpha, matA, matB, &beta, matC,
          CUSPARSE_MM_ALG_DEFAULT, &bufferSize));

      Tensor buffer;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(
          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
      DCHECK(buffer.flat<int8>().data() != nullptr);

      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
                                          buffer.flat<int8>().data()));

      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));

#else

#if GOOGLE_CUDA

      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

#elif TENSORFLOW_USE_ROCM

      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;

      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // GOOGLE_CUDA

      TF_RETURN_IF_ERROR(
          cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                            a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                            b.data(), ldb, &beta, c.data(), ldc));

#endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
    }

    return Status::OK();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
                                                   &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0.
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

#if GOOGLE_CUDA && CUDA_VERSION < 10020
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
      gpusparseMatDescr_t descrA;
      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_GPUSPARSE_ERROR(
          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif  // (GOOGLE_CUDA && CUDA_VERSION < 10020) || TENSORFLOW_USE_ROCM

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#else
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
#endif
    }

    return Status::OK();
  }

 private:
  Status status_;
  const gpusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow