• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <atomic>
17 #include <numeric>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/op_requires.h"
21 
22 #define EIGEN_USE_THREADS
23 
24 #include "third_party/eigen3/Eigen/Core"
25 #include "third_party/eigen3/Eigen/SparseCholesky"
26 #include "third_party/eigen3/Eigen/SparseCore"
27 #include "third_party/eigen3/Eigen/OrderingMethods"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/tensor_types.h"
32 #include "tensorflow/core/framework/variant_op_registry.h"
33 #include "tensorflow/core/kernels/sparse/kernels.h"
34 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
35 #include "tensorflow/core/util/work_sharder.h"
36 
37 namespace tensorflow {
38 
39 // Op to compute the sparse Cholesky factorization of a sparse matrix.
40 //
41 // Implements a CPU kernel which returns the lower triangular sparse Cholesky
42 // factor of a CSRSparseMatrix, using the fill-in reducing permutation.
43 //
44 // The CSRSparseMatrix may represent a single sparse matrix (rank 2) or a batch
45 // of sparse matrices (rank 3). Each component must represent a symmetric
46 // positive definite (SPD) matrix. In particular, this means the component
47 // matrices must be square. We don't actually check if the input is symmetric,
48 // only the lower triangular part of each component is read.
49 //
50 // The associated permutation must be a Tensor of rank (R - 1), where the
51 // CSRSparseMatrix has rank R. Additionally, the batch dimension of the
52 // CSRSparseMatrix and the permutation must be the same. Each batch of
53 // the permutation should the contain each of the integers [0,..,N - 1] exactly
54 // once, where N is the number of rows of each CSR SparseMatrix component.
55 // TODO(anudhyan): Add checks to throw an InvalidArgument error if the
56 // permutation is not valid.
57 //
58 // Returns a CSRSparseMatrix representing the lower triangular (batched)
59 // Cholesky factors. It has the same shape as the input CSRSparseMatrix. For
60 // each component sparse matrix A, the corresponding output sparse matrix L
61 // satisfies the identity:
62 //   A = L * Lt
63 // where Lt denotes the adjoint of L.
64 //
65 // TODO(b/126472741): Due to the multiple batches of a 3D CSRSparseMatrix being
66 // laid out in contiguous memory, this implementation allocates memory to store
67 // a temporary copy of the Cholesky factor. Consequently, it uses roughly twice
68 // the amount of memory that it needs to. This may cause a memory blowup for
69 // sparse matrices with a high number of non-zero elements.
70 template <typename T>
71 class CSRSparseCholeskyCPUOp : public OpKernel {
72   // Note: We operate in column major (CSC) format in this Op since the
73   // SimplicialLLT returns the factor in column major.
74   using SparseMatrix = Eigen::SparseMatrix<T, Eigen::ColMajor>;
75 
76  public:
CSRSparseCholeskyCPUOp(OpKernelConstruction * c)77   explicit CSRSparseCholeskyCPUOp(OpKernelConstruction* c) : OpKernel(c) {}
78 
Compute(OpKernelContext * ctx)79   void Compute(OpKernelContext* ctx) final {
80     // Extract inputs and validate shapes and types.
81     const CSRSparseMatrix* input_matrix;
82     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
83     const Tensor& input_permutation_indices = ctx->input(1);
84 
85     int64_t num_rows;
86     int batch_size;
87     OP_REQUIRES_OK(ctx, ValidateInputs(*input_matrix, input_permutation_indices,
88                                        &batch_size, &num_rows));
89 
90     // Allocate batch pointers.
91     Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
92     auto batch_ptr_vec = batch_ptr.vec<int32>();
93     batch_ptr_vec(0) = 0;
94 
95     // Temporary vector of Eigen SparseMatrices to store the Sparse Cholesky
96     // factors.
97     // Note: we use column-compressed (CSC) SparseMatrix because SimplicialLLT
98     // returns the factors in column major format. Since our input should be
99     // symmetric, column major and row major is identical in storage. We just
100     // have to switch to reading the upper triangular part of the input, which
101     // corresponds to the lower triangular part in row major format.
102     std::vector<SparseMatrix> sparse_cholesky_factors(batch_size);
103 
104     // TODO(anudhyan): Tune the cost per unit based on benchmarks.
105     const double nnz_per_row =
106         (input_matrix->total_nnz() / batch_size) / num_rows;
107     const int64_t sparse_cholesky_cost_per_batch =
108         nnz_per_row * nnz_per_row * num_rows;
109     // Perform sparse Cholesky factorization of each batch in parallel.
110     auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
111     std::atomic<int64_t> invalid_input_index(-1);
112     Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
113           sparse_cholesky_cost_per_batch,
114           [&](int64_t batch_begin, int64_t batch_end) {
115             for (int64_t batch_index = batch_begin; batch_index < batch_end;
116                  ++batch_index) {
117               // Define an Eigen SparseMatrix Map to operate on the
118               // CSRSparseMatrix component without copying the data.
119               Eigen::Map<const SparseMatrix> sparse_matrix(
120                   num_rows, num_rows, input_matrix->nnz(batch_index),
121                   input_matrix->row_pointers_vec(batch_index).data(),
122                   input_matrix->col_indices_vec(batch_index).data(),
123                   input_matrix->values_vec<T>(batch_index).data());
124 
125               Eigen::SimplicialLLT<SparseMatrix, Eigen::Upper,
126                                    Eigen::NaturalOrdering<int>>
127                   solver;
128               auto permutation_indices_flat =
129                   input_permutation_indices.flat<int32>().data();
130 
131               // Invert the fill-in reducing ordering and apply it to the input
132               // sparse matrix.
133               Eigen::Map<
134                   Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic, int>>
135                   permutation(permutation_indices_flat + batch_index * num_rows,
136                               num_rows);
137               auto permutation_inverse = permutation.inverse();
138 
139               SparseMatrix permuted_sparse_matrix;
140               permuted_sparse_matrix.template selfadjointView<Eigen::Upper>() =
141                   sparse_matrix.template selfadjointView<Eigen::Upper>()
142                       .twistedBy(permutation_inverse);
143 
144               // Compute the Cholesky decomposition.
145               solver.compute(permuted_sparse_matrix);
146               if (solver.info() != Eigen::Success) {
147                 invalid_input_index = batch_index;
148                 return;
149               }
150 
151               // Get the upper triangular factor, which would end up in the
152               // lower triangular part of the output CSRSparseMatrix when
153               // interpreted in row major format.
154               sparse_cholesky_factors[batch_index] =
155                   std::move(solver.matrixU());
156               // For now, batch_ptr contains the number of nonzeros in each
157               // batch.
158               batch_ptr_vec(batch_index + 1) =
159                   sparse_cholesky_factors[batch_index].nonZeros();
160             }
161           });
162 
163     // Check for invalid input.
164     OP_REQUIRES(
165         ctx, invalid_input_index == -1,
166         errors::InvalidArgument(
167             "Sparse Cholesky factorization failed for batch index ",
168             invalid_input_index.load(), ". The input might not be valid."));
169 
170     // Compute a cumulative sum to obtain the batch pointers.
171     std::partial_sum(batch_ptr_vec.data(),
172                      batch_ptr_vec.data() + batch_size + 1,
173                      batch_ptr_vec.data());
174 
175     // Allocate output Tensors.
176     const int64_t total_nnz = batch_ptr_vec(batch_size);
177     Tensor output_row_ptr(cpu_allocator(), DT_INT32,
178                           TensorShape({(num_rows + 1) * batch_size}));
179     Tensor output_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
180     Tensor output_values(cpu_allocator(), DataTypeToEnum<T>::value,
181                          TensorShape({total_nnz}));
182     auto output_row_ptr_ptr = output_row_ptr.flat<int32>().data();
183     auto output_col_ind_ptr = output_col_ind.flat<int32>().data();
184     auto output_values_ptr = output_values.flat<T>().data();
185 
186     // Copy the output matrices from each batch into the CSRSparseMatrix
187     // Tensors.
188     // TODO(b/129906419): Factor out the copy from Eigen SparseMatrix to
189     // CSRSparseMatrix into common utils. This is also used in
190     // SparseMatrixSparseMatMul.
191     Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
192           (3 * total_nnz) / batch_size /* cost per unit */,
193           [&](int64_t batch_begin, int64_t batch_end) {
194             for (int64_t batch_index = batch_begin; batch_index < batch_end;
195                  ++batch_index) {
196               const SparseMatrix& cholesky_factor =
197                   sparse_cholesky_factors[batch_index];
198               const int64_t nnz = cholesky_factor.nonZeros();
199 
200               std::copy(cholesky_factor.outerIndexPtr(),
201                         cholesky_factor.outerIndexPtr() + num_rows + 1,
202                         output_row_ptr_ptr + batch_index * (num_rows + 1));
203               std::copy(cholesky_factor.innerIndexPtr(),
204                         cholesky_factor.innerIndexPtr() + nnz,
205                         output_col_ind_ptr + batch_ptr_vec(batch_index));
206               std::copy(cholesky_factor.valuePtr(),
207                         cholesky_factor.valuePtr() + nnz,
208                         output_values_ptr + batch_ptr_vec(batch_index));
209             }
210           });
211 
212     // Create the CSRSparseMatrix instance from its component Tensors and
213     // prepare the Variant output Tensor.
214     CSRSparseMatrix output_csr_matrix;
215     OP_REQUIRES_OK(
216         ctx,
217         CSRSparseMatrix::CreateCSRSparseMatrix(
218             DataTypeToEnum<T>::value, input_matrix->dense_shape(), batch_ptr,
219             output_row_ptr, output_col_ind, output_values, &output_csr_matrix));
220     Tensor* output_csr_matrix_tensor;
221     AllocatorAttributes cpu_alloc;
222     cpu_alloc.set_on_host(true);
223     OP_REQUIRES_OK(
224         ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
225                                   cpu_alloc));
226     output_csr_matrix_tensor->scalar<Variant>()() =
227         std::move(output_csr_matrix);
228   }
229 
230  private:
ValidateInputs(const CSRSparseMatrix & sparse_matrix,const Tensor & permutation_indices,int * batch_size,int64_t * num_rows)231   Status ValidateInputs(const CSRSparseMatrix& sparse_matrix,
232                         const Tensor& permutation_indices, int* batch_size,
233                         int64_t* num_rows) {
234     if (sparse_matrix.dtype() != DataTypeToEnum<T>::value)
235       return errors::InvalidArgument(
236           "Asked for a CSRSparseMatrix of type ",
237           DataTypeString(DataTypeToEnum<T>::value),
238           " but saw dtype: ", DataTypeString(sparse_matrix.dtype()));
239 
240     const Tensor& dense_shape = sparse_matrix.dense_shape();
241     const int rank = dense_shape.dim_size(0);
242     if (rank < 2 || rank > 3)
243       return errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
244                                      "but dense_shape has size ", rank);
245     const int row_dim = (rank == 2) ? 0 : 1;
246     auto dense_shape_vec = dense_shape.vec<int64_t>();
247     *num_rows = dense_shape_vec(row_dim);
248     const int64_t num_cols = dense_shape_vec(row_dim + 1);
249     if (*num_rows != num_cols)
250       return errors::InvalidArgument(
251           "sparse matrix must be square; got: ", *num_rows, " != ", num_cols);
252     const TensorShape& perm_shape = permutation_indices.shape();
253     if (perm_shape.dims() + 1 != rank)
254       return errors::InvalidArgument(
255           "sparse matrix must have the same rank as permutation; got: ", rank,
256           " != ", perm_shape.dims(), " + 1.");
257     if (perm_shape.dim_size(rank - 2) != *num_rows)
258       return errors::InvalidArgument(
259           "permutation must have the same number of elements in each batch "
260           "as the number of rows in sparse matrix; got: ",
261           perm_shape.dim_size(rank - 2), " != ", *num_rows);
262 
263     *batch_size = sparse_matrix.batch_size();
264     if (*batch_size > 1) {
265       if (perm_shape.dim_size(0) != *batch_size)
266         return errors::InvalidArgument(
267             "permutation must have the same batch size "
268             "as sparse matrix; got: ",
269             perm_shape.dim_size(0), " != ", *batch_size);
270     }
271 
272     return OkStatus();
273   }
274 };
275 
276 #define REGISTER_CPU(T)                                      \
277   REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseCholesky") \
278                               .Device(DEVICE_CPU)            \
279                               .TypeConstraint<T>("type"),    \
280                           CSRSparseCholeskyCPUOp<T>);
281 REGISTER_CPU(float);
282 REGISTER_CPU(double);
283 REGISTER_CPU(complex64);
284 REGISTER_CPU(complex128);
285 
286 #undef REGISTER_CPU
287 
288 }  // namespace tensorflow
289