/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

static const char kThomasFailedMsg[] =
    "The matrix is either not invertible, or requires pivoting. "
    "Try setting partial_pivoting = True.";

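// Kernel for the TridiagonalSolve op. For each matrix, the three diagonals
// are supplied as a single [3, num_eqs] input (superdiagonal, diagonal, and
// subdiagonal rows) and the right-hand sides as a [num_eqs, num_rhs] input;
// batching over leading dimensions is handled by the LinearAlgebraOp base.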
template <class Scalar>
class TridiagonalSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);
  using MatrixMapRow =
      decltype(std::declval<const ConstMatrixMaps>()[0].row(0));

  explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
    perturb_singular_ = false;
    if (context->HasAttr("perturb_singular")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("perturb_singular", &perturb_singular_));
    }
    OP_REQUIRES(context, pivoting_ || !perturb_singular_,
                errors::InvalidArgument("Setting perturb_singular requires "
                                        "also setting partial_pivoting."));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_eqs_left = input_matrix_shapes[0].dim_size(1);
    auto num_eqs_right = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(
        context, num_eqs_left == num_eqs_right,
        errors::InvalidArgument("Expected the same number of left-hand sides "
                                "and right-hand sides, got ",
                                num_eqs_left, " and ", num_eqs_right, "."));
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1));
    const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0));

    const double add_cost = Eigen::TensorOpCost::AddCost<Scalar>();
    const double mult_cost = Eigen::TensorOpCost::MulCost<Scalar>();
    const double div_cost = Eigen::TensorOpCost::DivCost<Scalar>();

    double cost;
    if (pivoting_) {
      // Assuming cases with and without row interchange are equiprobable.
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2.5 * num_rhss + 1.5));
    } else {
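      // Without pivoting (Thomas algorithm): per row, one division to form
      // u(i) and one multiply-add for the denominator, plus, per right-hand
      // side, one division and one multiply-add in the forward sweep and one
      // multiply-add in the back substitution.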
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2 * num_rhss + 1));
    }
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];

    // Superdiagonal elements, the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements, the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];

    const int n = diag.size();
    MatrixMap& x = outputs->at(0);
    constexpr Scalar zero(0);

    if (n == 0) {
      return;
    }
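    // The perturb_singular solver handles the n == 1 case itself, so dispatch
    // to it before the scalar special case below.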
    if (pivoting_ && perturb_singular_) {
      SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
          context, superdiag, diag, subdiag, rhs, x);
      return;
    }

    if (n == 1) {
      if (diag(0) == zero) {
        LOG(WARNING) << kNotInvertibleScalarMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      } else {
        x.row(0) = rhs.row(0) / diag(0);
      }
      return;
    }

    if (pivoting_) {
      SolveWithGaussianEliminationWithPivoting(context, superdiag, diag,
                                               subdiag, rhs, x);
    } else {
      SolveWithThomasAlgorithm(context, superdiag, diag, subdiag, rhs, x);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp);

  // Adjust pivot such that neither 'rhs[i,:] / pivot' nor '1 / pivot' causes
  // overflow, where i indexes the multiple right-hand sides. During the
  // back-substitution phase in
  // SolveWithGaussianEliminationWithPivotingAndPerturbSingular, we compute
  // the i-th row of the solution as rhs[i,:] * (1 / pivot). This logic is
  // extracted from the LAPACK routine xLAGTS.
  void MaybePerturbPivot(RealScalar perturb, Scalar& pivot,
                         Eigen::Matrix<Scalar, 1, Eigen::Dynamic>& rhs_row) {
    constexpr RealScalar one(1);
    // The following logic is extracted from xLAMCH in LAPACK.
    constexpr RealScalar tiny = std::numeric_limits<RealScalar>::min();
    constexpr RealScalar small = one / std::numeric_limits<RealScalar>::max();
    constexpr RealScalar safemin =
        (small < tiny
             ? tiny
             : (one + std::numeric_limits<RealScalar>::epsilon()) * small);
    constexpr RealScalar bignum = one / safemin;

    RealScalar abs_pivot = std::abs(pivot);
    if (abs_pivot >= one) {
      return;
    }
    // Safeguard against infinite loop if 'perturb' is zero.
    // 'perturb' should never have magnitude smaller than safemin.
    perturb = std::max(std::abs(perturb), safemin);
    // Make sure perturb and pivot have the same sign.
    perturb = std::copysign(perturb, std::real(pivot));

    bool stop = false;
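    // Largest magnitude among the entries of rhs_row; used below to detect
    // whether rhs_row / pivot would overflow.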
    const RealScalar max_factor = rhs_row.array().abs().maxCoeff();
    while (abs_pivot < one && !stop) {
      if (abs_pivot < safemin) {
        if (abs_pivot == 0 || max_factor * safemin > abs_pivot) {
          pivot += perturb;
          perturb *= 2;
        } else {
          pivot *= bignum;
          rhs_row *= bignum;
          stop = true;
        }
      } else if (max_factor > abs_pivot * bignum) {
        pivot += perturb;
        perturb *= 2;
      } else {
        stop = true;
      }
      abs_pivot = std::abs(pivot);
    }
  }

  // This function roughly follows LAPACK's xLAGTF + xLAGTS routines.
  //
  // It computes the solution to a linear system with multiple
  // right-hand sides
  //     T * X = RHS
  // where T is a tridiagonal matrix, using a row-pivoted LU decomposition.
  //
  // This routine differs from SolveWithGaussianEliminationWithPivoting by
  // allowing the tridiagonal matrix to be numerically singular.
  // If tiny diagonal elements of U are encountered, signaling that T is
  // numerically singular, the diagonal elements are perturbed by
  // an amount proportional to eps*max_abs_u to avoid overflow, where
  // max_abs_u is max_{i,j} | U(i,j) |. This is useful when using this
  // routine for computing eigenvectors of a matrix T' via inverse
  // iteration by solving the singular system
  //   (T' - lambda*I) X = RHS,
  // where lambda is an eigenvalue of T'.
  //
  // By fusing the factorization and solution, we avoid storing L
  // and pivoting information, and the forward solve is done on-the-fly
  // during factorization, instead of requiring a separate loop.
  void SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
      OpKernelContext* context, const MatrixMapRow& superdiag,
      const MatrixMapRow& diag, const MatrixMapRow& subdiag,
      const ConstMatrixMap& rhs, MatrixMap& x) {
    constexpr Scalar zero(0);
    constexpr RealScalar realzero(0);
    constexpr Scalar one(1);
    constexpr RealScalar eps = std::numeric_limits<RealScalar>::epsilon();

    const int n = diag.size();
    if (n == 0) return;
    if (n == 1) {
      Scalar denom = diag(0);
      RealScalar tol = eps * std::abs(denom);
      Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = rhs.row(0);
      MaybePerturbPivot(tol, denom, row);
      x = row * (one / denom);
      return;
    }

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition
    // of the input matrix (subject to row exchanges due to pivoting). For
    // a pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // We accumulate max( abs( U(i,j) ) ) in max_abs_u for use in perturbing
    // near-zero pivots during the solution phase.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    RealScalar max_abs_u = std::max(std::abs(u(0, 0)), std::abs(u(0, 1)));
    RealScalar scale1 = std::abs(u(0, 0)) + std::abs(u(0, 1));
    x.row(0) = rhs.row(0);
    for (int k = 0; k < n - 1; ++k) {
      // The non-zeros in the (k+1)-st row are
      //    [ ... subdiag(k+1) diag(k+1) superdiag(k+1) ... ]
      u(k + 1, 0) = diag(k + 1);
      RealScalar scale2 = std::abs(subdiag(k + 1)) + std::abs(u(k + 1, 0));
      if (k < n - 2) scale2 += std::abs(superdiag(k + 1));
      if (subdiag(k + 1) == zero) {
        // The sub-diagonal in the (k+1)-st row is already zero. Move to the
        // next row.
        scale1 = scale2;
        u(k + 1, 1) = superdiag(k + 1);
        u(k, 2) = zero;
        x.row(k + 1) = rhs.row(k + 1);
      } else {
        const RealScalar piv1 =
            u(k, 0) == zero ? realzero : std::abs(u(k, 0)) / scale1;
        const RealScalar piv2 = std::abs(subdiag(k + 1)) / scale2;
        if (piv2 <= piv1) {
          // No row pivoting needed.
          scale1 = scale2;
          Scalar factor = subdiag(k + 1) / u(k, 0);
          u(k + 1, 0) = diag(k + 1) - factor * u(k, 1);
          u(k + 1, 1) = superdiag(k + 1);
          u(k, 2) = zero;
          x.row(k + 1) = rhs.row(k + 1) - factor * x.row(k);
        } else {
          // Swap rows k and k+1.
          Scalar factor = u(k, 0) / subdiag(k + 1);
          u(k, 0) = subdiag(k + 1);
          u(k + 1, 0) = u(k, 1) - factor * diag(k + 1);
          u(k, 1) = diag(k + 1);
          if (k < n - 2) {
            u(k, 2) = superdiag(k + 1);
            u(k + 1, 1) = -factor * superdiag(k + 1);
          }
          x.row(k + 1) = x.row(k) - factor * rhs.row(k + 1);
          x.row(k) = rhs.row(k + 1);
        }
      }
      if (k < n - 2) {
        for (int i = 0; i < 3; ++i) {
          max_abs_u = std::max(max_abs_u, std::abs(u(k, i)));
        }
      }
    }
    max_abs_u = std::max(max_abs_u, std::abs(u(n - 1, 0)));

    // We have already solved L z = P rhs above. Now we solve U x = z,
    // possibly perturbing small pivots to avoid overflow. The variable tol
    // contains eps * max( abs( u(:,:) ) ). If tiny pivots are encountered,
    // they are perturbed by a small amount on the scale of tol to avoid
    // overflow or scaled up to avoid underflow.
    RealScalar tol = eps * max_abs_u;
    Scalar denom = u(n - 1, 0);
    Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = x.row(n - 1);
    MaybePerturbPivot(tol, denom, row);
    x.row(n - 1) = row * (one / denom);
    if (n > 1) {
      denom = u(n - 2, 0);
      row = x.row(n - 2) - u(n - 2, 1) * x.row(n - 1);
      MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
      x.row(n - 2) = row * (one / denom);

      for (int k = n - 3; k >= 0; --k) {
        row = x.row(k) - u(k, 1) * x.row(k + 1) - u(k, 2) * x.row(k + 2);
        denom = u(k, 0);
        MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
        x.row(k) = row * (one / denom);
      }
    }
  }

  void SolveWithGaussianEliminationWithPivoting(OpKernelContext* context,
                                                const MatrixMapRow& superdiag,
                                                const MatrixMapRow& diag,
                                                const MatrixMapRow& subdiag,
                                                const ConstMatrixMap& rhs,
                                                MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition of
    // the input matrix (subject to row exchanges due to pivoting). For a
    // pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // The code below roughly follows LAPACK's dgtsv routine, with the main
    // difference being that the input is not overwritten.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    x.row(0) = rhs.row(0);
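    // Forward elimination with partial pivoting: at step i, compare |u(i, 0)|
    // with |subdiag(i + 1)| to decide whether to interchange rows i and i + 1,
    // then eliminate the subdiagonal entry and apply the same update to the
    // right-hand sides stored in x.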
    for (int i = 0; i < n - 1; ++i) {
      if (std::abs(u(i)) >= std::abs(subdiag(i + 1))) {
        // No row interchange.
        if (u(i) == zero) {
          LOG(WARNING) << kNotInvertibleMsg;
          x.fill(std::numeric_limits<Scalar>::quiet_NaN());
          return;
        }
        const Scalar factor = subdiag(i + 1) / u(i, 0);
        u(i + 1, 0) = diag(i + 1) - factor * u(i, 1);
        x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i);
        if (i != n - 2) {
          u(i + 1, 1) = superdiag(i + 1);
          u(i, 2) = 0;
        }
      } else {
        // Interchange rows i and i + 1.
        const Scalar factor = u(i, 0) / subdiag(i + 1);
        u(i, 0) = subdiag(i + 1);
        u(i + 1, 0) = u(i, 1) - factor * diag(i + 1);
        u(i, 1) = diag(i + 1);
        x.row(i + 1) = x.row(i) - factor * rhs.row(i + 1);
        x.row(i) = rhs.row(i + 1);
        if (i != n - 2) {
          u(i, 2) = superdiag(i + 1);
          u(i + 1, 1) = -factor * superdiag(i + 1);
        }
      }
    }
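    // Back substitution: x currently holds z, the forward-eliminated
    // right-hand sides. Solve U x = z using the at most two superdiagonals
    // stored in u, bailing out if the last pivot is exactly zero.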
    if (u(n - 1, 0) == zero) {
      LOG(WARNING) << kNotInvertibleMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }
    x.row(n - 1) /= u(n - 1, 0);
    x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0);
    for (int i = n - 3; i >= 0; --i) {
      x.row(i) = (x.row(i) - u(i, 1) * x.row(i + 1) - u(i, 2) * x.row(i + 2)) /
                 u(i, 0);
    }
  }

  void SolveWithThomasAlgorithm(OpKernelContext* context,
                                const MatrixMapRow& superdiag,
                                const MatrixMapRow& diag,
                                const MatrixMapRow& subdiag,
                                const ConstMatrixMap& rhs, MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The superdiagonal of the U matrix in the LU decomposition of the input
    // matrix (in the Thomas algorithm, the U matrix has ones on the diagonal
    // and one superdiagonal).
    Eigen::Matrix<Scalar, Eigen::Dynamic, 1> u(n);

    if (diag(0) == zero) {
      LOG(WARNING) << kThomasFailedMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }

    u(0) = superdiag(0) / diag(0);
    x.row(0) = rhs.row(0) / diag(0);
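    // Forward sweep: for each row, eliminate the subdiagonal entry using the
    // pivot denom = diag(i) - subdiag(i) * u(i - 1), normalizing the row so
    // that U ends up with a unit diagonal.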
    for (int i = 1; i < n; ++i) {
      auto denom = diag(i) - subdiag(i) * u(i - 1);
      if (denom == zero) {
        LOG(WARNING) << kThomasFailedMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
        return;
      }
      u(i) = superdiag(i) / denom;
      x.row(i) = (rhs.row(i) - subdiag(i) * x.row(i - 1)) / denom;
    }
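    // Back substitution: since U has a unit diagonal, each solution row only
    // needs to subtract u(i) times the already-computed next row.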
    for (int i = n - 2; i >= 0; --i) {
      x.row(i) -= u(i) * x.row(i + 1);
    }
  }

  bool pivoting_;
  bool perturb_singular_;
};

REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<double>),
                       double);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex64>),
                       complex64);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex128>),
                       complex128);
}  // namespace tensorflow