#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif

namespace at::native {

using namespace at::sparse;

/*
    This is an implementation of the SMMP algorithm:
     "Sparse Matrix Multiplication Package (SMMP)"

      Randolph E. Bank and Craig C. Douglas
      https://doi.org/10.1007/BF02070824
*/

namespace {

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
  /*
    Expands a compressed row pointer into a row indices array.
    Inputs:
      `n_row` is the number of rows in `Ap`
      `Ap` is the row pointer

    Output:
      `Bi` is the row indices
  */
  for (const auto i : c10::irange(n_row)) {
    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      Bi[jj] = i;
    }
  }
}

template <class index_t_ptr>
int64_t _csr_matmult_maxnnz(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const index_t_ptr Bp,
    const index_t_ptr Bj) {
  /*
    Compute the needed buffer size for matrix `C` in the `C = A@B` operation.

    The matrices should be in proper CSR structure, and their dimensions
    should be compatible.
  */
  std::vector<int64_t> mask(n_col, -1);
  int64_t nnz = 0;
  for (const auto i : c10::irange(n_row)) {
    int64_t row_nnz = 0;

    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      int64_t j = Aj[jj];
      for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) {
        int64_t k = Bj[kk];
        if (mask[k] != i) {
          mask[k] = i;
          row_nnz++;
        }
      }
    }
    int64_t next_nnz = nnz + row_nnz;
    nnz = next_nnz;
  }
  return nnz;
}

template <class index_t_ptr, class scalar_t_ptr>
void _csr_matmult(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const scalar_t_ptr Ax,
    const index_t_ptr Bp,
    const index_t_ptr Bj,
    const scalar_t_ptr Bx,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename scalar_t_ptr::value_type Cx[]) {
  /*
    Compute CSR entries for matrix C = A@B.

    The matrices `A` and `B` should be in proper CSR structure, and their
    dimensions should be compatible.

    Inputs:
      `n_row`       - number of rows in A
      `n_col`       - number of columns in B
      `Ap[n_row+1]` - row pointer
      `Aj[nnz(A)]`  - column indices
      `Ax[nnz(A)]`  - nonzeros
      `Bp[?]`       - row pointer
      `Bj[nnz(B)]`  - column indices
      `Bx[nnz(B)]`  - nonzeros
    Outputs:
      `Cp[n_row+1]` - row pointer
      `Cj[nnz(C)]`  - column indices
      `Cx[nnz(C)]`  - nonzeros

    Note:
      Output arrays Cp, Cj, and Cx must be preallocated.
  */
  using index_t = typename index_t_ptr::value_type;
  using scalar_t = typename scalar_t_ptr::value_type;

  std::vector<index_t> next(n_col, -1);
  std::vector<scalar_t> sums(n_col, 0);

  int64_t nnz = 0;

  Cp[0] = 0;

  for (const auto i : c10::irange(n_row)) {
    index_t head = -2;
    index_t length = 0;

    index_t jj_start = Ap[i];
    index_t jj_end = Ap[i + 1];
    for (const auto jj : c10::irange(jj_start, jj_end)) {
      index_t j = Aj[jj];
      scalar_t v = Ax[jj];

      index_t kk_start = Bp[j];
      index_t kk_end = Bp[j + 1];
      for (const auto kk : c10::irange(kk_start, kk_end)) {
        index_t k = Bj[kk];

        sums[k] += v * Bx[kk];

        if (next[k] == -1) {
          next[k] = head;
          head = k;
          length++;
        }
      }
    }

    for (C10_UNUSED const auto jj : c10::irange(length)) {
      // NOTE: the linked list that encodes col indices
      // is not guaranteed to be sorted.
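      // Pop the head of the list: emit its column index and accumulated value,
      // then reset the next/sums entries so the per-column buffers can be
      // reused by subsequent rows.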
      Cj[nnz] = head;
      Cx[nnz] = sums[head];
      nnz++;

      index_t temp = head;
      head = next[head];

      // clear arrays
      next[temp] = -1;
      sums[temp] = 0;
    }

    // Make sure that col indices are sorted.
    // TODO: a better approach is to implement a CSR @ CSC kernel.
    // NOTE: Cx arrays are expected to be contiguous!
    auto col_indices_accessor = StridedRandomAccessor<index_t>(Cj + nnz - length, 1);
    auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
    auto kv_accessor = CompositeRandomAccessorCPU<
      decltype(col_indices_accessor), decltype(val_accessor)
    >(col_indices_accessor, val_accessor);

    std::sort(kv_accessor, kv_accessor + length,
        [](const auto& lhs, const auto& rhs) -> bool {
          return get<0>(lhs) < get<0>(rhs);
        });

    Cp[i + 1] = nnz;
  }
}

template <typename scalar_t>
void sparse_matmul_kernel(
    Tensor& output,
    const Tensor& mat1,
    const Tensor& mat2) {
  /*
    Computes the sparse-sparse matrix multiplication between `mat1` and `mat2`,
    which are sparse tensors in COO format.
  */

  auto M = mat1.size(0);
  auto N = mat2.size(1);

  const auto mat1_csr = mat1.to_sparse_csr();
  const auto mat2_csr = mat2.to_sparse_csr();

  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.crow_indices().data_ptr<int64_t>(),
      mat1_csr.crow_indices().stride(-1));
  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.col_indices().data_ptr<int64_t>(),
      mat1_csr.col_indices().stride(-1));
  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
      mat1_csr.values().data_ptr<scalar_t>(),
      mat1_csr.values().stride(-1));
  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.crow_indices().data_ptr<int64_t>(),
      mat2_csr.crow_indices().stride(-1));
  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.col_indices().data_ptr<int64_t>(),
      mat2_csr.col_indices().stride(-1));
  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
      mat2_csr.values().data_ptr<scalar_t>(),
      mat2_csr.values().stride(-1));

  const auto nnz = _csr_matmult_maxnnz(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr);

  auto output_indices = output._indices();
  auto output_values = output._values();

  Tensor output_indptr = at::empty({M + 1}, kLong);
  at::native::resize_output(output_indices, {2, nnz});
  at::native::resize_output(output_values, nnz);

  Tensor output_row_indices = output_indices.select(0, 0);
  Tensor output_col_indices = output_indices.select(0, 1);

  // TODO: replace with a CSR @ CSC kernel for better performance.
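  // _csr_matmult fills output_indptr / output_col_indices / output_values with
  // the CSR representation of mat1 @ mat2; csr_to_coo then expands the row
  // pointer into explicit row indices to obtain the COO layout of `output`.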
  _csr_matmult(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat1_values_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr,
      mat2_values_ptr,
      output_indptr.data_ptr<int64_t>(),
      output_col_indices.data_ptr<int64_t>(),
      output_values.data_ptr<scalar_t>());

  csr_to_coo(
      M,
      output_indptr.data_ptr<int64_t>(),
      output_row_indices.data_ptr<int64_t>());
  output._coalesced_(true);
}

} // end anonymous namespace

Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  TORCH_CHECK(mat1_.dim() == 2);
  TORCH_CHECK(mat2_.dim() == 2);
  TORCH_CHECK(
      mat1_.dense_dim() == 0,
      "sparse_sparse_matmul_cpu: scalar values expected, got ",
      mat1_.dense_dim(),
      "D values");
  TORCH_CHECK(
      mat2_.dense_dim() == 0,
      "sparse_sparse_matmul_cpu: scalar values expected, got ",
      mat2_.dense_dim(),
      "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0),
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ",
      mat2_.size(0), "x", mat2_.size(1), ")");
  TORCH_CHECK(
      mat1_.scalar_type() == mat2_.scalar_type(),
      "mat1 dtype ", mat1_.scalar_type(),
      " does not match mat2 dtype ", mat2_.scalar_type());

  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_(
      {mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      mat1_.scalar_type(), "sparse_matmul", [&] {
        sparse_matmul_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
      });
  return output;
}

} // namespace at::native
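// Usage sketch (assumption: this function is the CPU backend of
// at::_sparse_sparse_matmul; the snippet below is illustrative only):
//
//   auto a = at::randn({4, 5}).to_sparse();    // 4x5 sparse COO
//   auto b = at::randn({5, 3}).to_sparse();    // 5x3 sparse COO
//   auto c = at::_sparse_sparse_matmul(a, b);  // 4x3 coalesced sparse COO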