1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <atomic> 17 #include <numeric> 18 #include <vector> 19 20 #include "tensorflow/core/framework/op_requires.h" 21 22 #define EIGEN_USE_THREADS 23 24 #include "third_party/eigen3/Eigen/Core" 25 #include "third_party/eigen3/Eigen/SparseCholesky" 26 #include "third_party/eigen3/Eigen/SparseCore" 27 #include "third_party/eigen3/Eigen/OrderingMethods" 28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 29 #include "tensorflow/core/framework/op.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/framework/tensor_types.h" 32 #include "tensorflow/core/framework/variant_op_registry.h" 33 #include "tensorflow/core/kernels/sparse/kernels.h" 34 #include "tensorflow/core/kernels/sparse/sparse_matrix.h" 35 #include "tensorflow/core/util/work_sharder.h" 36 37 namespace tensorflow { 38 39 // Op to compute the sparse Cholesky factorization of a sparse matrix. 40 // 41 // Implements a CPU kernel which returns the lower triangular sparse Cholesky 42 // factor of a CSRSparseMatrix, using the fill-in reducing permutation. 43 // 44 // The CSRSparseMatrix may represent a single sparse matrix (rank 2) or a batch 45 // of sparse matrices (rank 3). Each component must represent a symmetric 46 // positive definite (SPD) matrix. In particular, this means the component 47 // matrices must be square. We don't actually check if the input is symmetric, 48 // only the lower triangular part of each component is read. 49 // 50 // The associated permutation must be a Tensor of rank (R - 1), where the 51 // CSRSparseMatrix has rank R. Additionally, the batch dimension of the 52 // CSRSparseMatrix and the permutation must be the same. Each batch of 53 // the permutation should the contain each of the integers [0,..,N - 1] exactly 54 // once, where N is the number of rows of each CSR SparseMatrix component. 55 // TODO(anudhyan): Add checks to throw an InvalidArgument error if the 56 // permutation is not valid. 57 // 58 // Returns a CSRSparseMatrix representing the lower triangular (batched) 59 // Cholesky factors. It has the same shape as the input CSRSparseMatrix. For 60 // each component sparse matrix A, the corresponding output sparse matrix L 61 // satisfies the identity: 62 // A = L * Lt 63 // where Lt denotes the adjoint of L. 64 // 65 // TODO(b/126472741): Due to the multiple batches of a 3D CSRSparseMatrix being 66 // laid out in contiguous memory, this implementation allocates memory to store 67 // a temporary copy of the Cholesky factor. Consequently, it uses roughly twice 68 // the amount of memory that it needs to. This may cause a memory blowup for 69 // sparse matrices with a high number of non-zero elements. 70 template <typename T> 71 class CSRSparseCholeskyCPUOp : public OpKernel { 72 // Note: We operate in column major (CSC) format in this Op since the 73 // SimplicialLLT returns the factor in column major. 74 using SparseMatrix = Eigen::SparseMatrix<T, Eigen::ColMajor>; 75 76 public: CSRSparseCholeskyCPUOp(OpKernelConstruction * c)77 explicit CSRSparseCholeskyCPUOp(OpKernelConstruction* c) : OpKernel(c) {} 78 Compute(OpKernelContext * ctx)79 void Compute(OpKernelContext* ctx) final { 80 // Extract inputs and validate shapes and types. 81 const CSRSparseMatrix* input_matrix; 82 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix)); 83 const Tensor& input_permutation_indices = ctx->input(1); 84 85 int64_t num_rows; 86 int batch_size; 87 OP_REQUIRES_OK(ctx, ValidateInputs(*input_matrix, input_permutation_indices, 88 &batch_size, &num_rows)); 89 90 // Allocate batch pointers. 91 Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1})); 92 auto batch_ptr_vec = batch_ptr.vec<int32>(); 93 batch_ptr_vec(0) = 0; 94 95 // Temporary vector of Eigen SparseMatrices to store the Sparse Cholesky 96 // factors. 97 // Note: we use column-compressed (CSC) SparseMatrix because SimplicialLLT 98 // returns the factors in column major format. Since our input should be 99 // symmetric, column major and row major is identical in storage. We just 100 // have to switch to reading the upper triangular part of the input, which 101 // corresponds to the lower triangular part in row major format. 102 std::vector<SparseMatrix> sparse_cholesky_factors(batch_size); 103 104 // TODO(anudhyan): Tune the cost per unit based on benchmarks. 105 const double nnz_per_row = 106 (input_matrix->total_nnz() / batch_size) / num_rows; 107 const int64_t sparse_cholesky_cost_per_batch = 108 nnz_per_row * nnz_per_row * num_rows; 109 // Perform sparse Cholesky factorization of each batch in parallel. 110 auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); 111 std::atomic<int64_t> invalid_input_index(-1); 112 Shard(worker_threads.num_threads, worker_threads.workers, batch_size, 113 sparse_cholesky_cost_per_batch, 114 [&](int64_t batch_begin, int64_t batch_end) { 115 for (int64_t batch_index = batch_begin; batch_index < batch_end; 116 ++batch_index) { 117 // Define an Eigen SparseMatrix Map to operate on the 118 // CSRSparseMatrix component without copying the data. 119 Eigen::Map<const SparseMatrix> sparse_matrix( 120 num_rows, num_rows, input_matrix->nnz(batch_index), 121 input_matrix->row_pointers_vec(batch_index).data(), 122 input_matrix->col_indices_vec(batch_index).data(), 123 input_matrix->values_vec<T>(batch_index).data()); 124 125 Eigen::SimplicialLLT<SparseMatrix, Eigen::Upper, 126 Eigen::NaturalOrdering<int>> 127 solver; 128 auto permutation_indices_flat = 129 input_permutation_indices.flat<int32>().data(); 130 131 // Invert the fill-in reducing ordering and apply it to the input 132 // sparse matrix. 133 Eigen::Map< 134 Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic, int>> 135 permutation(permutation_indices_flat + batch_index * num_rows, 136 num_rows); 137 auto permutation_inverse = permutation.inverse(); 138 139 SparseMatrix permuted_sparse_matrix; 140 permuted_sparse_matrix.template selfadjointView<Eigen::Upper>() = 141 sparse_matrix.template selfadjointView<Eigen::Upper>() 142 .twistedBy(permutation_inverse); 143 144 // Compute the Cholesky decomposition. 145 solver.compute(permuted_sparse_matrix); 146 if (solver.info() != Eigen::Success) { 147 invalid_input_index = batch_index; 148 return; 149 } 150 151 // Get the upper triangular factor, which would end up in the 152 // lower triangular part of the output CSRSparseMatrix when 153 // interpreted in row major format. 154 sparse_cholesky_factors[batch_index] = 155 std::move(solver.matrixU()); 156 // For now, batch_ptr contains the number of nonzeros in each 157 // batch. 158 batch_ptr_vec(batch_index + 1) = 159 sparse_cholesky_factors[batch_index].nonZeros(); 160 } 161 }); 162 163 // Check for invalid input. 164 OP_REQUIRES( 165 ctx, invalid_input_index == -1, 166 errors::InvalidArgument( 167 "Sparse Cholesky factorization failed for batch index ", 168 invalid_input_index.load(), ". The input might not be valid.")); 169 170 // Compute a cumulative sum to obtain the batch pointers. 171 std::partial_sum(batch_ptr_vec.data(), 172 batch_ptr_vec.data() + batch_size + 1, 173 batch_ptr_vec.data()); 174 175 // Allocate output Tensors. 176 const int64_t total_nnz = batch_ptr_vec(batch_size); 177 Tensor output_row_ptr(cpu_allocator(), DT_INT32, 178 TensorShape({(num_rows + 1) * batch_size})); 179 Tensor output_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz})); 180 Tensor output_values(cpu_allocator(), DataTypeToEnum<T>::value, 181 TensorShape({total_nnz})); 182 auto output_row_ptr_ptr = output_row_ptr.flat<int32>().data(); 183 auto output_col_ind_ptr = output_col_ind.flat<int32>().data(); 184 auto output_values_ptr = output_values.flat<T>().data(); 185 186 // Copy the output matrices from each batch into the CSRSparseMatrix 187 // Tensors. 188 // TODO(b/129906419): Factor out the copy from Eigen SparseMatrix to 189 // CSRSparseMatrix into common utils. This is also used in 190 // SparseMatrixSparseMatMul. 191 Shard(worker_threads.num_threads, worker_threads.workers, batch_size, 192 (3 * total_nnz) / batch_size /* cost per unit */, 193 [&](int64_t batch_begin, int64_t batch_end) { 194 for (int64_t batch_index = batch_begin; batch_index < batch_end; 195 ++batch_index) { 196 const SparseMatrix& cholesky_factor = 197 sparse_cholesky_factors[batch_index]; 198 const int64_t nnz = cholesky_factor.nonZeros(); 199 200 std::copy(cholesky_factor.outerIndexPtr(), 201 cholesky_factor.outerIndexPtr() + num_rows + 1, 202 output_row_ptr_ptr + batch_index * (num_rows + 1)); 203 std::copy(cholesky_factor.innerIndexPtr(), 204 cholesky_factor.innerIndexPtr() + nnz, 205 output_col_ind_ptr + batch_ptr_vec(batch_index)); 206 std::copy(cholesky_factor.valuePtr(), 207 cholesky_factor.valuePtr() + nnz, 208 output_values_ptr + batch_ptr_vec(batch_index)); 209 } 210 }); 211 212 // Create the CSRSparseMatrix instance from its component Tensors and 213 // prepare the Variant output Tensor. 214 CSRSparseMatrix output_csr_matrix; 215 OP_REQUIRES_OK( 216 ctx, 217 CSRSparseMatrix::CreateCSRSparseMatrix( 218 DataTypeToEnum<T>::value, input_matrix->dense_shape(), batch_ptr, 219 output_row_ptr, output_col_ind, output_values, &output_csr_matrix)); 220 Tensor* output_csr_matrix_tensor; 221 AllocatorAttributes cpu_alloc; 222 cpu_alloc.set_on_host(true); 223 OP_REQUIRES_OK( 224 ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor, 225 cpu_alloc)); 226 output_csr_matrix_tensor->scalar<Variant>()() = 227 std::move(output_csr_matrix); 228 } 229 230 private: ValidateInputs(const CSRSparseMatrix & sparse_matrix,const Tensor & permutation_indices,int * batch_size,int64_t * num_rows)231 Status ValidateInputs(const CSRSparseMatrix& sparse_matrix, 232 const Tensor& permutation_indices, int* batch_size, 233 int64_t* num_rows) { 234 if (sparse_matrix.dtype() != DataTypeToEnum<T>::value) 235 return errors::InvalidArgument( 236 "Asked for a CSRSparseMatrix of type ", 237 DataTypeString(DataTypeToEnum<T>::value), 238 " but saw dtype: ", DataTypeString(sparse_matrix.dtype())); 239 240 const Tensor& dense_shape = sparse_matrix.dense_shape(); 241 const int rank = dense_shape.dim_size(0); 242 if (rank < 2 || rank > 3) 243 return errors::InvalidArgument("sparse matrix must have rank 2 or 3; ", 244 "but dense_shape has size ", rank); 245 const int row_dim = (rank == 2) ? 0 : 1; 246 auto dense_shape_vec = dense_shape.vec<int64_t>(); 247 *num_rows = dense_shape_vec(row_dim); 248 const int64_t num_cols = dense_shape_vec(row_dim + 1); 249 if (*num_rows != num_cols) 250 return errors::InvalidArgument( 251 "sparse matrix must be square; got: ", *num_rows, " != ", num_cols); 252 const TensorShape& perm_shape = permutation_indices.shape(); 253 if (perm_shape.dims() + 1 != rank) 254 return errors::InvalidArgument( 255 "sparse matrix must have the same rank as permutation; got: ", rank, 256 " != ", perm_shape.dims(), " + 1."); 257 if (perm_shape.dim_size(rank - 2) != *num_rows) 258 return errors::InvalidArgument( 259 "permutation must have the same number of elements in each batch " 260 "as the number of rows in sparse matrix; got: ", 261 perm_shape.dim_size(rank - 2), " != ", *num_rows); 262 263 *batch_size = sparse_matrix.batch_size(); 264 if (*batch_size > 1) { 265 if (perm_shape.dim_size(0) != *batch_size) 266 return errors::InvalidArgument( 267 "permutation must have the same batch size " 268 "as sparse matrix; got: ", 269 perm_shape.dim_size(0), " != ", *batch_size); 270 } 271 272 return OkStatus(); 273 } 274 }; 275 276 #define REGISTER_CPU(T) \ 277 REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseCholesky") \ 278 .Device(DEVICE_CPU) \ 279 .TypeConstraint<T>("type"), \ 280 CSRSparseCholeskyCPUOp<T>); 281 REGISTER_CPU(float); 282 REGISTER_CPU(double); 283 REGISTER_CPU(complex64); 284 REGISTER_CPU(complex128); 285 286 #undef REGISTER_CPU 287 288 } // namespace tensorflow 289