#include #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #else #include #include #endif namespace at::native { DEFINE_DISPATCH(flatten_indices_stub); } // namespace at::native namespace at::sparse { // NOTE [ Flatten Sparse Indices ] // This helper function flattens a sparse indices tensor (a Tensor) into a 1D // indices tensor. E.g., // input = [[2, 4, 0], // [3, 1, 10]] // full_size = [2, 12] // output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10] // // In other words, assuming that each `indices[i, :]` is a valid index to a // tensor `t` of shape `full_size`. This returns the corresponding indices to // the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`. // if forceClone is true, the result will forced to be a clone of self. // if force_clone is true, the result will forced to be a clone of self. Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) { int64_t sparse_dim = indices.size(0); if (sparse_dim == 1) { if (force_clone) { return indices.squeeze(0).clone(at::MemoryFormat::Contiguous); } else { return indices.squeeze(0); } } else { if (!indices.numel()) { return at::zeros({indices.size(1)}, indices.options().dtype(kLong)); } return at::native::flatten_indices_stub(indices.device().type(), indices, full_size.slice(0, sparse_dim)); } } // Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ], // except this one allows partial flatten: only flatten on specified dims. Note that // the flatten indices might be uncoalesced if dims_to_flatten.size() < sparse_dim. // Also if input indices is already coalesced, the flattened indices will also be sorted. // // args: // indices: sparse tensor indices // sizes: sparse tensor sizes // dims_to_flatten: a list of dim index to flatten // // Ex1: // indices = [[2, 4, 0], // [3, 1, 3]] // sizes = [2, 12] // dims_to_flatten = [0, 1] // new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3] // // Ex2: // dims_to_flatten = [1] // new_indices = [ 3, 1, 3 ] # uncoalesced Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten){ Tensor new_indices = at::zeros({indices.size(1)}, indices.options()); for (auto d : dims_to_flatten) { new_indices.mul_(sizes[d]); new_indices.add_(indices.select(0, d)); } return new_indices; } Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) { /* Find the CSR representation for a row `indices` from the COO format Inputs: `indices` is the row pointer from COO indices `dim` is the row dimensionality `nnz` is the number of non-zeros Output: `csr` is a compressed row array in a CSR format */ Tensor csr = at::zeros({dim + 1}, kLong); // TODO: eliminate this conditional when zero-size dims supported correctly if (nnz > 0) { auto csr_accessor = csr.accessor(); // Convert the sparse matrix to CSR format at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t h, hp0, hp1; for (const auto i : c10::irange(start, end)) { hp0 = indices[i]; hp1 = (i+1 == nnz) ? dim : indices[i+1]; if (hp0 != hp1) { for (h = hp0; h < hp1; h++) { csr_accessor[h+1] = i+1; } } } }); } return csr; } Tensor zeros_like_with_indices(const Tensor& t) { TORCH_INTERNAL_ASSERT(t.is_sparse()); return at::_sparse_coo_tensor_with_dims_and_tensors( t.sparse_dim(), t.dense_dim(), t.sizes(), t._indices().clone(), at::zeros({1}, t._values().options()).expand_as(t._values()), t.options(), t.is_coalesced()); } Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options) { const auto max_size = *std::max_element(sizes.begin(), sizes.end()); const auto max_size_arange = at::arange(max_size, options); std::vector stack; stack.reserve(sizes.size()); for (size_t i=0; i < sizes.size(); i++) { Tensor a = max_size_arange.narrow(-1, 0, sizes[i]); for (size_t j=0; j < sizes.size(); j++) { if (i != j) { a.unsqueeze_(j); } } stack.push_back(a.expand(sizes)); } return at::stack(stack).flatten(1, -1); } } // namespace at::sparse