/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include <memory>
#include <numeric>

#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Swaps the dim sizes at two given dimensions of a TensorShape.
// Callers are responsible for making sure the given dimensions are within the
// valid dimension range of the TensorShape.
void SwapDimSizes(const int dim_a, const int dim_b, TensorShape* shape) {
  const int64 size_a = shape->dim_size(dim_a);
  const int64 size_b = shape->dim_size(dim_b);
  shape->set_dim(dim_a, size_b);
  shape->set_dim(dim_b, size_a);
}

}  // namespace

// Op to compute the matrix multiplication of two CSR Sparse Matrices.
//
// Implements a CPU kernel to perform matrix multiplication using Eigen
// SparseMatrix and its Sparse-Sparse matmul. Supports transposing and
// adjointing on the fly for both inputs without actually constructing the
// transpose or adjoint.
//
// This implementation does not support broadcasting. Hence both input
// CSRSparseMatrices must have the same rank (either rank 2 or rank 3).
//
// The output sparse matrix may contain numeric (non-structural) zeros.
// TODO(anudhyan): Consider exposing whether to prune zeros as an attribute in
// the op's interface.
//
// If multiple threads are available, we parallelize across multiple batches
// using the Eigen ThreadPool. Within a single batch, we run in single-threaded
// mode because Eigen's Sparse-Sparse matmul doesn't support multithreading.
//
// TODO(b/126472741): Due to the multiple batches of a 3D CSRSparseMatrix being
// laid out in contiguous memory, this implementation allocates memory to store
// a temporary copy of the matrix product. Consequently, it uses roughly twice
// the amount of memory that it needs to. This may cause a memory blowup for
// sparse matrices with a high number of non-zero elements.
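//
// The flow of Compute() is: validate dtypes, batch sizes, and inner
// dimensions; infer the output dense shape; run one Eigen sparse-sparse
// product per batch (sharded across the worker threads); turn the per-batch
// nonzero counts into batch pointers via a prefix sum; and finally copy each
// batch's row pointers, column indices, and values into the output
// CSRSparseMatrix.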
template <typename T>
class CSRSparseMatMulCPUOp : public OpKernel {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;

 public:
  explicit CSRSparseMatMulCPUOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a_));
    OP_REQUIRES(c, !(adjoint_a_ && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b_));
    OP_REQUIRES(c, !(adjoint_b_ && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
  }

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* input_matrix_a;
    const CSRSparseMatrix* input_matrix_b;
    // TODO(anudhyan): Factor out common validation logic in CPU and GPU Ops
    // into a common base class.
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix_a));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &input_matrix_b));
    OP_REQUIRES(ctx, input_matrix_a->dtype() == DataTypeToEnum<T>::value,
                errors::InvalidArgument(
                    "dtype of a is not equal to 'type': ",
                    DataTypeString(input_matrix_a->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(ctx, input_matrix_b->dtype() == DataTypeToEnum<T>::value,
                errors::InvalidArgument(
                    "dtype of b is not equal to 'type': ",
                    DataTypeString(input_matrix_b->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(ctx,
                input_matrix_a->batch_size() == input_matrix_b->batch_size(),
                errors::InvalidArgument(
                    "Batch sizes of A and B do not agree. Batch sizes are: ",
                    input_matrix_a->batch_size(), " vs. ",
                    input_matrix_b->batch_size()));

    // Validate input_matrix_a's and input_matrix_b's shapes
    TensorShape a_shape;
    TensorShape b_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(
                       input_matrix_a->dense_shape().vec<int64>(), &a_shape));
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(
                       input_matrix_b->dense_shape().vec<int64>(), &b_shape));

    const int rank = a_shape.dims();
    const int row_dim = (rank == 2) ? 0 : 1;
    if (transpose_a_ || adjoint_a_)
      SwapDimSizes(row_dim, row_dim + 1, &a_shape);
    if (transpose_b_ || adjoint_b_)
      SwapDimSizes(row_dim, row_dim + 1, &b_shape);

    OP_REQUIRES(
        ctx, a_shape.dim_size(row_dim + 1) == b_shape.dim_size(row_dim),
        errors::InvalidArgument(
            "Inner product dimensions of A and B do not agree. Shapes are: ",
            a_shape.DebugString(), " vs. ", b_shape.DebugString()));

    // Infer the output shape of the matrix product.
    // TODO(ebrevdo): MatMul support for broadcasting at least in the
    // batch dimension.
    const int batch_size = input_matrix_a->batch_size();
    Tensor output_shape(cpu_allocator(), DT_INT64, TensorShape({rank}));
    auto output_shape_vec = output_shape.vec<int64>();
    if (rank == 3) output_shape_vec(0) = batch_size;
    output_shape_vec(row_dim) = a_shape.dim_size(row_dim);
    output_shape_vec(row_dim + 1) = b_shape.dim_size(row_dim + 1);

    // Set batch pointers.
    Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
    auto batch_ptr_vec = batch_ptr.vec<int32>();
    batch_ptr_vec(0) = 0;

    // Store intermediate matrix products for each batch.
    // TODO(b/126472741): For a single batch, consider reusing the
    // SparseMatrices' buffers to construct the CSRSparseMatrix to prevent 2x
    // memory usage.
    std::vector<SparseMatrix> output_matrices(batch_size);

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    // Estimate the cost per batch as num_output_rows times the product of the
    // average number of nonzeros per row of A and B.
    const int64 num_output_rows = output_shape_vec(row_dim);
    const double avg_nnz_per_row_a =
        input_matrix_a->total_nnz() /
        static_cast<double>(a_shape.dim_size(row_dim) * batch_size);
    const double avg_nnz_per_row_b =
        input_matrix_b->total_nnz() /
        static_cast<double>(b_shape.dim_size(row_dim) * batch_size);
    const int64 matmul_cost_per_batch =
        num_output_rows * (avg_nnz_per_row_a * avg_nnz_per_row_b);
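    // This roughly equals the expected number of scalar multiply-adds in one
    // batch's sparse-sparse product when the nonzeros are spread evenly
    // across rows.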

    // Parallelize matrix multiplication across batches.
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          matmul_cost_per_batch, [&](int64 batch_begin, int64 batch_end) {
            for (int64 batch_idx = batch_begin; batch_idx < batch_end;
                 ++batch_idx) {
              // For each batch, map the CSRSparseMatrix as Eigen SparseMatrix
              // without copying the underlying data.
              auto a_ref = GetSparseMatrixRef(*input_matrix_a, rank, batch_idx,
                                              transpose_a_, adjoint_a_);
              auto b_ref = GetSparseMatrixRef(*input_matrix_b, rank, batch_idx,
                                              transpose_b_, adjoint_b_);

              // Matrix multiply while *not* pruning numerical zeros on the fly.
              // Allocates output SparseMatrix and moves it to our list of
              // output_matrices.
              output_matrices[batch_idx] = a_ref * b_ref;

              // For now, batch_ptr contains the number of nonzeros in each
              // batch.
              batch_ptr_vec(batch_idx + 1) =
                  output_matrices[batch_idx].nonZeros();
            }
          });

    // Compute the cumulative sum to obtain the batch pointers.
    std::partial_sum(batch_ptr_vec.data(),
                     batch_ptr_vec.data() + batch_size + 1,
                     batch_ptr_vec.data());
    const int64 total_nnz = batch_ptr_vec(batch_size);

    // Allocate output tensors.
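    // The row pointers of all batches are stored back to back: batch b
    // occupies entries [b * (num_output_rows + 1), (b + 1) * (num_output_rows
    // + 1)) of output_row_ptr, while its column indices and values start at
    // offset batch_ptr_vec(b).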
    Tensor output_row_ptr(cpu_allocator(), DT_INT32,
                          TensorShape({(num_output_rows + 1) * batch_size}));
    Tensor output_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
    Tensor output_values(cpu_allocator(), DataTypeToEnum<T>::value,
                         TensorShape({total_nnz}));
    auto output_row_ptr_ptr = output_row_ptr.flat<int32>().data();
    auto output_col_ind_ptr = output_col_ind.flat<int32>().data();
    auto output_values_ptr = output_values.flat<T>().data();

    // Copy the output matrices from each batch into the CSRSparseMatrix
    // tensors.
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          (3 * total_nnz) / batch_size /* cost per unit */,
          [&](int64 batch_begin, int64 batch_end) {
            for (int64 batch_idx = batch_begin; batch_idx < batch_end;
                 ++batch_idx) {
              const SparseMatrix& output_matrix = output_matrices[batch_idx];
              const int64 nnz = output_matrix.nonZeros();
              std::copy(output_matrix.outerIndexPtr(),
                        output_matrix.outerIndexPtr() + num_output_rows + 1,
                        output_row_ptr_ptr + batch_idx * (num_output_rows + 1));
              std::copy(output_matrix.innerIndexPtr(),
                        output_matrix.innerIndexPtr() + nnz,
                        output_col_ind_ptr + batch_ptr_vec(batch_idx));
              std::copy(output_matrix.valuePtr(),
                        output_matrix.valuePtr() + nnz,
                        output_values_ptr + batch_ptr_vec(batch_idx));
            }
          });

    // Create the CSRSparseMatrix object from its component Tensors and prepare
    // the Variant output Tensor.
    CSRSparseMatrix output_csr_matrix;
    OP_REQUIRES_OK(ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
                            DataTypeToEnum<T>::value, output_shape, batch_ptr,
                            output_row_ptr, output_col_ind, output_values,
                            &output_csr_matrix));
    Tensor* output_csr_matrix_tensor;
    AllocatorAttributes cpu_alloc;
    cpu_alloc.set_on_host(true);
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
                                  cpu_alloc));
    output_csr_matrix_tensor->scalar<Variant>()() =
        std::move(output_csr_matrix);
  }

 private:
  // Returns an Eigen::Ref expression of a SparseMatrix that points to the
  // underlying memory of the given CSRSparseMatrix.
  Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
      const CSRSparseMatrix& csr_matrix, const int rank, const int batch_index,
      const bool transpose, const bool adjoint) {
    const auto dense_shape = csr_matrix.dense_shape().vec<int64>();
    const int64 num_rows = dense_shape(rank == 2 ? 0 : 1);
    const int64 num_cols = dense_shape(rank == 2 ? 1 : 2);

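    // Eigen's compressed row-major SparseMatrix uses the same storage layout
    // as CSR: the outer index pointers are the row pointers and the inner
    // indices are the column indices, so the Map below aliases the
    // CSRSparseMatrix buffers without copying.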
    Eigen::Map<const SparseMatrix> sparse_matrix(
        num_rows, num_cols, csr_matrix.nnz(batch_index),
        csr_matrix.row_pointers_vec(batch_index).data(),
        csr_matrix.col_indices_vec(batch_index).data(),
        csr_matrix.values_vec<T>(batch_index).data());

    // The transpose/adjoint expressions are not actually evaluated until
    // necessary. Hence we don't create copies or modify the input matrix
    // in place.
    if (transpose) return sparse_matrix.transpose();
    if (adjoint) return sparse_matrix.adjoint();
    return sparse_matrix;
  }

  bool transpose_a_;
  bool transpose_b_;
  bool adjoint_a_;
  bool adjoint_b_;
};

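// Op to compute the matrix multiplication of two CSR Sparse Matrices on GPU.
//
// Wraps the cuSPARSE/hipSPARSE csrgemm routines via the
// CSRSparseSparseMatrixMatMul functor below. Because the underlying GEMM is
// invoked in non-transposed mode, any requested transpose or adjoint is
// applied up front by explicitly (conjugate-)transposing the inputs. The
// multiplication then proceeds batch by batch: query the required workspace
// size (CUDA 10+ only), compute the output's sparsity structure (row pointers
// and per-batch nnz), then compute the column indices and values.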
template <typename Device, typename T>
class CSRSparseMatMulGPUOp : public OpKernel {
 public:
  explicit CSRSparseMatMulGPUOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
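    // Adjoints are handled by explicitly conjugate-transposing the
    // corresponding input in Compute() (via CSRSparseMatrixTranspose); the
    // underlying GEMM itself always runs in non-transposed mode.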
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ = transpose_a_ || adjoint_a;
    transpose_b_ = transpose_b_ || adjoint_b;
  }

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    const CSRSparseMatrix* b_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &b_matrix));
    OP_REQUIRES(
        ctx, a_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of a is not equal to 'type': ",
                                DataTypeString(a_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(
        ctx, b_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of b is not equal to 'type': ",
                                DataTypeString(b_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // TODO(ebrevdo): MatMul support for broadcasting at least in the
    // batch dimension.
    auto a_dense_shape = a_matrix->dense_shape().vec<int64>();
    auto b_dense_shape = b_matrix->dense_shape().vec<int64>();

    TensorShape a_tensor_shape;
    TensorShape b_tensor_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(a_dense_shape, &a_tensor_shape));
    OP_REQUIRES_OK(ctx,
                   TensorShapeUtils::MakeShape(b_dense_shape, &b_tensor_shape));

    const int rank = a_tensor_shape.dims();
    const int row_dim = (rank == 2) ? 0 : 1;

    const int64 a_inner_dim =
        a_tensor_shape.dim_size(transpose_a_ ? row_dim : row_dim + 1);
    const int64 b_inner_dim =
        b_tensor_shape.dim_size(transpose_b_ ? row_dim + 1 : row_dim);

    const int batch_size = a_matrix->batch_size();

    OP_REQUIRES(
        ctx, a_inner_dim == b_inner_dim,
        errors::InvalidArgument(
            "Inner product dimensions of A and B do not agree. Shapes are: ",
            a_tensor_shape.DebugString(), " vs. ",
            b_tensor_shape.DebugString()));

    Tensor c_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
    auto c_dense_shape = c_dense_shape_t.vec<int64>();

    if (rank == 3) c_dense_shape(0) = batch_size;
    c_dense_shape(row_dim) =
        a_tensor_shape.dim_size(transpose_a_ ? row_dim + 1 : row_dim);
    c_dense_shape(row_dim + 1) =
        b_tensor_shape.dim_size(transpose_b_ ? row_dim : row_dim + 1);

    const int64 rows = c_dense_shape((rank == 2) ? 0 : 1);

    CSRSparseMatrix c;
    Tensor c_row_ptrs;

    // TODO(ebrevdo): Re-enable transposing within the GEMM kernel when cuSparse
    // stops spitting out CUSPARSE_STATUS_INTERNAL_ERROR values for transposes.
    functor::CSRSparseSparseMatrixMatMul<Device, T> csr_gemm(
        ctx, /*transpose_a=*/false, /*adjoint_a=*/false, /*transpose_b=*/false);
    OP_REQUIRES_OK(ctx, csr_gemm.Initialize());

    Tensor c_batch_ptr_t(cpu_allocator(), DT_INT32,
                         TensorShape({batch_size + 1}));
    auto c_batch_ptr = c_batch_ptr_t.vec<int32>();
    c_batch_ptr(0) = 0;

    Tensor c_row_ptr_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DT_INT32, TensorShape({batch_size * (rows + 1)}),
                            &c_row_ptr_t));
    auto c_row_ptr = c_row_ptr_t.vec<int32>();
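
    // c_row_ptr stores the row pointers of all batches back to back: batch i
    // occupies entries [i * (rows + 1), (i + 1) * (rows + 1)).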

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<Device, T> transpose;
      OP_REQUIRES_OK(
          ctx, transpose(ctx, conjugate_a_, *a_matrix, &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }
    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    const CSRSparseMatrix* b_input_matrix;
    // If we need to transpose b, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix b_matrix_transposed;
    if (!transpose_b_) {
      b_input_matrix = b_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<Device, T> transpose;
      OP_REQUIRES_OK(
          ctx, transpose(ctx, conjugate_b_, *b_matrix, &b_matrix_transposed));
      b_input_matrix = &b_matrix_transposed;
    }
    auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();

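    // With CUDA 10+, the csrgemm2 API requires a caller-allocated workspace,
    // sized below as the maximum over all batches. The pre-CUDA-10 and ROCm
    // code paths take no workspace, so a null pointer is passed instead.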
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
    size_t maxWorkspaceSize = 0;
    for (int i = 0; i < batch_size; ++i) {
      // Calculate maximum workspace size over batch.
      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
                                  a_input_matrix->col_indices_vec(i),
                                  a_input_matrix->values_vec<T>(i),
                                  a_input_dense_shape};
      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
                                  b_input_matrix->col_indices_vec(i),
                                  b_input_matrix->values_vec<T>(i),
                                  b_input_dense_shape};
      size_t thisWorkspaceSize;
      OP_REQUIRES_OK(
          ctx, csr_gemm.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
      if (thisWorkspaceSize > maxWorkspaceSize) {
        maxWorkspaceSize = thisWorkspaceSize;
      }
    }

    Tensor temp;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}),
                 &temp));
    void* workspace = temp.flat<int8>().data();
#else
    void* workspace = nullptr;
#endif

    for (int i = 0; i < batch_size; ++i) {
      // Calculate output sizes for all minibatch entries.
      // Store in c_batch_ptr and update c_row_ptrs.
      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
                                  a_input_matrix->col_indices_vec(i),
                                  a_input_matrix->values_vec<T>(i),
                                  a_input_dense_shape};
      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
                                  b_input_matrix->col_indices_vec(i),
                                  b_input_matrix->values_vec<T>(i),
                                  b_input_dense_shape};

      TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
                                              rows + 1);

      int c_nnz_i;
      OP_REQUIRES_OK(ctx,
                     csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
                                                 &c_nnz_i, workspace));
      c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
    }

    Tensor c_col_ind_t;
    Tensor c_values_t;

    const int total_nnz = c_batch_ptr(batch_size);

    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                           &c_col_ind_t));
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::value,
                                      TensorShape({total_nnz}), &c_values_t));
    OP_REQUIRES_OK(ctx,
                   CSRSparseMatrix::CreateCSRSparseMatrix(
                       DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t,
                       c_row_ptr_t, c_col_ind_t, c_values_t, &c));

    for (int i = 0; i < batch_size; ++i) {
      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
                                  a_input_matrix->col_indices_vec(i),
                                  a_input_matrix->values_vec<T>(i),
                                  a_input_dense_shape};
      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
                                  b_input_matrix->col_indices_vec(i),
                                  b_input_matrix->values_vec<T>(i),
                                  b_input_dense_shape};
      CSRComponent<T> c_comp{c.row_pointers_vec(i), c.col_indices_vec(i),
                             c.values_vec<T>(i), c_dense_shape};
      OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp, workspace));
    }

    Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    c_t.scalar<Variant>()() = std::move(c);
    ctx->set_output(0, c_t);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
};

#define REGISTER_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseMatMul") \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("type"),  \
                          CSRSparseMatMulCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

#define REGISTER(DEV, T)                                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseMatMul") \
                              .Device(DEVICE_##DEV)        \
                              .TypeConstraint<T>("type"),  \
                          CSRSparseMatMulGPUOp<DEV##Device, T>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif  // GOOGLE_CUDA

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <typename T>
struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
    : public CSRStructureModifyingFunctor<GPUDevice, T> {
  explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
                                       bool adjoint_a, bool transpose_b)
      : ctx_(ctx),
        cuda_sparse_(ctx),
        initialized_(false),
        transpose_a_(transpose_a),
        adjoint_a_(adjoint_a),
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
        transpose_b_(transpose_b) {
#else
        transpose_b_(transpose_b),
        info_(nullptr) {
#endif  // CUDA_VERSION < 10000
    // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
    transA_ = transpose_a
                  ? (adjoint_a ? GPUSPARSE(OPERATION_TRANSPOSE)
                               : GPUSPARSE(OPERATION_CONJUGATE_TRANSPOSE))
                  : GPUSPARSE(OPERATION_NON_TRANSPOSE);
    transB_ = transpose_b ? GPUSPARSE(OPERATION_TRANSPOSE)
                          : GPUSPARSE(OPERATION_NON_TRANSPOSE);
  }

#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  ~CSRSparseSparseMatrixMatMul() {
    if (initialized_) {
      cusparseDestroyCsrgemm2Info(info_);
    }
  }
#endif

  Status Initialize() {
    if (adjoint_a_ && transpose_a_) {
      return errors::InvalidArgument(
          "Only one of adjoint_a and transpose_a may be true.");
    }

    TF_RETURN_IF_ERROR(cuda_sparse_.Initialize());
    TF_RETURN_IF_ERROR(descrA_.Initialize());
    TF_RETURN_IF_ERROR(descrB_.Initialize());
    TF_RETURN_IF_ERROR(descrC_.Initialize());
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
#endif
    initialized_ = true;
    return Status::OK();
  }

  Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                          const ConstCSRComponent<T>& b, size_t* bufferSize) {
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
    DCHECK(initialized_);
    const int m =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
    if (!transpose_a_) {
      DCHECK_EQ(m, a.row_ptr.size() - 1);
    }
    const int k =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
    if (!transpose_b_) {
      DCHECK_EQ(k, b.row_ptr.size() - 1);
    }
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n =
        b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
        m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
        bufferSize));
#endif

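    // On pre-CUDA-10 and ROCm builds there is no csrgemm2 buffer-size query;
    // *bufferSize is left unset here and no workspace is allocated by the
    // caller.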
    return Status::OK();
  }

  Status GetOutputStructure(const ConstCSRComponent<T>& a,
                            const ConstCSRComponent<T>& b,
                            TTypes<int32>::UnalignedVec c_row_ptr,
                            int* output_nnz, void* workspace) {
    DCHECK(initialized_);

    const int m =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
    if (!transpose_a_) {
      DCHECK_EQ(m, a.row_ptr.size() - 1);
    }
    DCHECK_EQ(m, c_row_ptr.size() - 1);
    const int k =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
    if (!transpose_b_) {
      DCHECK_EQ(k, b.row_ptr.size() - 1);
    }
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n =
        b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));

    *output_nnz = -1;

#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
        transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
        a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
        b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
#else
    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
        m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
        descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
#endif

    if (*output_nnz < 0) {
      return errors::Internal(
          "CSRMatMul: CsrgemmNnz returned nnzTotalDevHostPtr < 0: ",
          *output_nnz);
    }
    return Status::OK();
  }

  Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
                 CSRComponent<T>* c, void* workspace) {
    DCHECK(initialized_);

    const int m =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
    if (!transpose_a_) {
      DCHECK_EQ(m, a.row_ptr.size() - 1);
    }
    DCHECK_EQ(m, c->dense_shape_host(c->dense_shape_host.size() - 2));
    DCHECK_EQ(m, c->row_ptr.size() - 1);
    const int k =
        a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
    if (!transpose_b_) {
      DCHECK_EQ(k, b.row_ptr.size() - 1);
    }
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n =
        b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
    DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));

#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
        transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
        a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
        b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
        c->values.data(), c->row_ptr.data(), c->col_ind.data()));
#else
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
        m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
        a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
        c->row_ptr.data(), c->col_ind.data(), info_, workspace));
#endif

    // TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix
    // columns are sorted? Above operation leads to unsorted columns.
    // For now, here is an example of how to ensure the output columns
    // are sorted. Leaving here in case we realize we need to ensure
    // sorted columns in the future.
    //
    // TF_RETURN_IF_ERROR(cuda_sparse.Csru2csr(
    //     m, n, nnzTotalDevHostPtr, descrA_.descr(), c->values.data(),
    //     c->row_ptr.data(), c->col_ind.data()));

    return Status::OK();
  }

 private:
  OpKernelContext* ctx_;
  GpuSparse cuda_sparse_;
  bool initialized_;
  bool transpose_a_;
  bool adjoint_a_;
  bool transpose_b_;
  GpuSparseMatrixDescriptor descrA_;
  GpuSparseMatrixDescriptor descrB_;
  GpuSparseMatrixDescriptor descrC_;
  gpusparseOperation_t transA_;
  gpusparseOperation_t transB_;
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  csrgemm2Info_t info_;
#endif
};

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow