/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

static const char kThomasFailedMsg[] =
    "The matrix is either not invertible, or requires pivoting. "
    "Try setting partial_pivoting = True.";

template <class Scalar>
class TridiagonalSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);
  using MatrixMapRow =
      decltype(std::declval<const ConstMatrixMaps>()[0].row(0));

  explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
    perturb_singular_ = false;
    if (context->HasAttr("perturb_singular")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("perturb_singular", &perturb_singular_));
    }
    OP_REQUIRES(context, pivoting_ || !perturb_singular_,
                errors::InvalidArgument("Setting perturb_singular requires "
                                        "also setting partial_pivoting."));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_eqs_left = input_matrix_shapes[0].dim_size(1);
    auto num_eqs_right = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(
        context, num_eqs_left == num_eqs_right,
        errors::InvalidArgument("Expected the same number of left-hand sides "
                                "and right-hand sides, got ",
                                num_eqs_left, " and ", num_eqs_right, "."));
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }
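  // Returns a rough estimate of the number of scalar operations needed to
  // solve one system with `num_rhss` right-hand sides; the LinearAlgebraOp
  // base class uses this estimate to shard batched solves across threads.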
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1));
    const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0));

    const double add_cost = Eigen::TensorOpCost::AddCost<Scalar>();
    const double mult_cost = Eigen::TensorOpCost::MulCost<Scalar>();
    const double div_cost = Eigen::TensorOpCost::DivCost<Scalar>();

    double cost;
    if (pivoting_) {
      // Assuming cases with and without row interchange are equiprobable.
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2.5 * num_rhss + 1.5));
    } else {
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2 * num_rhss + 1));
    }
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }
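  // Dispatches to one of the three solvers below based on the
  // partial_pivoting and perturb_singular attributes; trivial cases (n == 0,
  // and n == 1 without perturbation) are handled inline.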
  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];

    // Superdiagonal elements; the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements; the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];

    const int n = diag.size();
    MatrixMap& x = outputs->at(0);
    constexpr Scalar zero(0);

    if (n == 0) {
      return;
    }
    if (pivoting_ && perturb_singular_) {
      SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
          context, superdiag, diag, subdiag, rhs, x);
      return;
    }

    if (n == 1) {
      if (diag(0) == zero) {
        LOG(WARNING) << kNotInvertibleScalarMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      } else {
        x.row(0) = rhs.row(0) / diag(0);
      }
      return;
    }

    if (pivoting_) {
      SolveWithGaussianEliminationWithPivoting(context, superdiag, diag,
                                               subdiag, rhs, x);
    } else {
      SolveWithThomasAlgorithm(context, superdiag, diag, subdiag, rhs, x);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp);

  // Adjusts pivot such that neither 'rhs[i,:] / pivot' nor '1 / pivot' can
  // overflow, where i indexes the multiple right-hand sides. During the
  // back-substitution phase in
  // SolveWithGaussianEliminationWithPivotingAndPerturbSingular, the i-th row
  // of the solution is computed as rhs[i,:] * (1 / pivot). This logic is
  // extracted from the LAPACK routine xLAGTS.
  void MaybePerturbPivot(RealScalar perturb, Scalar& pivot,
                         Eigen::Matrix<Scalar, 1, Eigen::Dynamic>& rhs_row) {
    constexpr RealScalar one(1);
    // The following logic is extracted from xLAMCH in LAPACK.
    constexpr RealScalar tiny = std::numeric_limits<RealScalar>::min();
    constexpr RealScalar small = one / std::numeric_limits<RealScalar>::max();
    constexpr RealScalar safemin =
        (small < tiny
             ? tiny
             : (one + std::numeric_limits<RealScalar>::epsilon()) * small);
    constexpr RealScalar bignum = one / safemin;

    RealScalar abs_pivot = std::abs(pivot);
    if (abs_pivot >= one) {
      return;
    }
    // Safeguard against an infinite loop if 'perturb' is zero:
    // 'perturb' should never have magnitude smaller than safemin.
    perturb = std::max(std::abs(perturb), safemin);
    // Make sure perturb and pivot have the same sign.
    perturb = std::copysign(perturb, std::real(pivot));

    bool stop = false;
    const RealScalar max_factor = rhs_row.array().abs().maxCoeff();
    while (abs_pivot < one && !stop) {
      if (abs_pivot < safemin) {
        if (abs_pivot == 0 || max_factor * safemin > abs_pivot) {
          pivot += perturb;
          perturb *= 2;
        } else {
          pivot *= bignum;
          rhs_row *= bignum;
          stop = true;
        }
      } else if (max_factor > abs_pivot * bignum) {
        pivot += perturb;
        perturb *= 2;
      } else {
        stop = true;
      }
      abs_pivot = std::abs(pivot);
    }
  }

  // This function roughly follows LAPACK's xLAGTF + xLAGTS routines.
  //
  // It computes the solution to a linear system with multiple right-hand
  // sides
  //     T * X = RHS
  // where T is a tridiagonal matrix, using a row-pivoted LU decomposition.
  //
  // This routine differs from SolveWithGaussianEliminationWithPivoting by
  // allowing the tridiagonal matrix to be numerically singular.
  // If tiny diagonal elements of U are encountered, signaling that T is
  // numerically singular, the diagonal elements are perturbed by
  // an amount proportional to eps*max_abs_u to avoid overflow, where
  // max_abs_u is max_{i,j} | U(i,j) |. This is useful when using this
  // routine for computing eigenvectors of a matrix T' via inverse
  // iteration by solving the singular system
  //     (T' - lambda*I) X = RHS,
  // where lambda is an eigenvalue of T'.
  //
  // By fusing the factorization and solution, we avoid storing L
  // and pivoting information, and the forward solve is done on-the-fly
  // during factorization, instead of requiring a separate loop.
  void SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
      OpKernelContext* context, const MatrixMapRow& superdiag,
      const MatrixMapRow& diag, const MatrixMapRow& subdiag,
      const ConstMatrixMap& rhs, MatrixMap& x) {
    constexpr Scalar zero(0);
    constexpr RealScalar realzero(0);
    constexpr Scalar one(1);
    constexpr RealScalar eps = std::numeric_limits<RealScalar>::epsilon();

    const int n = diag.size();
    if (n == 0) return;
    if (n == 1) {
      Scalar denom = diag(0);
      RealScalar tol = eps * std::abs(denom);
      Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = rhs.row(0);
      MaybePerturbPivot(tol, denom, row);
      x = row * (one / denom);
      return;
    }

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition
    // of the input matrix (subject to row exchanges due to pivoting). For
    // a pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // We accumulate max( abs( U(i,j) ) ) in max_abs_u for use in perturbing
    // near-zero pivots during the solution phase.
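    // Forward pass: eliminate the subdiagonal with partial pivoting (using
    // scaled row magnitudes to decide whether to swap), applying the same
    // row operations to the right-hand sides on the fly, so that on exit
    // x holds the forward-substituted right-hand sides and u holds the
    // non-zero entries of U.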
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    RealScalar max_abs_u = std::max(std::abs(u(0, 0)), std::abs(u(0, 1)));
    RealScalar scale1 = std::abs(u(0, 0)) + std::abs(u(0, 1));
    x.row(0) = rhs.row(0);
    for (int k = 0; k < n - 1; ++k) {
      // The non-zeros in the (k+1)-st row are
      //   [ ... subdiag(k+1) diag(k+1) superdiag(k+1) ... ]
      u(k + 1, 0) = diag(k + 1);
      RealScalar scale2 = std::abs(subdiag(k + 1)) + std::abs(u(k + 1, 0));
      if (k < n - 2) scale2 += std::abs(superdiag(k + 1));
      if (subdiag(k + 1) == zero) {
        // The sub-diagonal in the (k+1)-st row is already zero. Move to the
        // next row.
        scale1 = scale2;
        u(k + 1, 1) = superdiag(k + 1);
        u(k, 2) = zero;
        x.row(k + 1) = rhs.row(k + 1);
      } else {
        const RealScalar piv1 =
            u(k, 0) == zero ? realzero : std::abs(u(k, 0)) / scale1;
        const RealScalar piv2 = std::abs(subdiag(k + 1)) / scale2;
        if (piv2 <= piv1) {
          // No row pivoting needed.
          scale1 = scale2;
          Scalar factor = subdiag(k + 1) / u(k, 0);
          u(k + 1, 0) = diag(k + 1) - factor * u(k, 1);
          u(k + 1, 1) = superdiag(k + 1);
          u(k, 2) = zero;
          x.row(k + 1) = rhs.row(k + 1) - factor * x.row(k);
        } else {
          // Swap rows k and k+1.
          Scalar factor = u(k, 0) / subdiag(k + 1);
          u(k, 0) = subdiag(k + 1);
          u(k + 1, 0) = u(k, 1) - factor * diag(k + 1);
          u(k, 1) = diag(k + 1);
          if (k < n - 2) {
            u(k, 2) = superdiag(k + 1);
            u(k + 1, 1) = -factor * superdiag(k + 1);
          }
          x.row(k + 1) = x.row(k) - factor * rhs.row(k + 1);
          x.row(k) = rhs.row(k + 1);
        }
      }
      if (k < n - 2) {
        for (int i = 0; i < 3; ++i) {
          max_abs_u = std::max(max_abs_u, std::abs(u(k, i)));
        }
      }
    }
    max_abs_u = std::max(max_abs_u, std::abs(u(n - 1, 0)));

    // We have already solved L z = P rhs above. Now we solve U x = z,
    // possibly perturbing small pivots to avoid overflow. The variable tol
    // contains eps * max( abs( u(:,:) ) ). If tiny pivots are encountered,
    // they are perturbed by a small amount on the scale of tol to avoid
    // overflow, or scaled up to avoid underflow.
    RealScalar tol = eps * max_abs_u;
    Scalar denom = u(n - 1, 0);
    Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = x.row(n - 1);
    MaybePerturbPivot(tol, denom, row);
    x.row(n - 1) = row * (one / denom);
    if (n > 1) {
      denom = u(n - 2, 0);
      row = x.row(n - 2) - u(n - 2, 1) * x.row(n - 1);
      MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
      x.row(n - 2) = row * (one / denom);

      for (int k = n - 3; k >= 0; --k) {
        row = x.row(k) - u(k, 1) * x.row(k + 1) - u(k, 2) * x.row(k + 2);
        denom = u(k, 0);
        MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
        x.row(k) = row * (one / denom);
      }
    }
  }
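  // Solves T * X = RHS by Gaussian elimination with partial (row) pivoting,
  // roughly following LAPACK's dgtsv. If the matrix turns out to be singular,
  // a warning is logged and the output is filled with NaNs.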
  void SolveWithGaussianEliminationWithPivoting(OpKernelContext* context,
                                                const MatrixMapRow& superdiag,
                                                const MatrixMapRow& diag,
                                                const MatrixMapRow& subdiag,
                                                const ConstMatrixMap& rhs,
                                                MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition of
    // the input matrix (subject to row exchanges due to pivoting). For a
    // pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // The code below roughly follows LAPACK's dgtsv routine, with the main
    // difference being that the input is not overwritten.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    x.row(0) = rhs.row(0);
    for (int i = 0; i < n - 1; ++i) {
      if (std::abs(u(i)) >= std::abs(subdiag(i + 1))) {
        // No row interchange.
        if (u(i) == zero) {
          LOG(WARNING) << kNotInvertibleMsg;
          x.fill(std::numeric_limits<Scalar>::quiet_NaN());
          return;
        }
        const Scalar factor = subdiag(i + 1) / u(i, 0);
        u(i + 1, 0) = diag(i + 1) - factor * u(i, 1);
        x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i);
        if (i != n - 2) {
          u(i + 1, 1) = superdiag(i + 1);
          u(i, 2) = 0;
        }
      } else {
        // Interchange rows i and i + 1.
        const Scalar factor = u(i, 0) / subdiag(i + 1);
        u(i, 0) = subdiag(i + 1);
        u(i + 1, 0) = u(i, 1) - factor * diag(i + 1);
        u(i, 1) = diag(i + 1);
        x.row(i + 1) = x.row(i) - factor * rhs.row(i + 1);
        x.row(i) = rhs.row(i + 1);
        if (i != n - 2) {
          u(i, 2) = superdiag(i + 1);
          u(i + 1, 1) = -factor * superdiag(i + 1);
        }
      }
    }
    if (u(n - 1, 0) == zero) {
      LOG(WARNING) << kNotInvertibleMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }
    x.row(n - 1) /= u(n - 1, 0);
    x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0);
    for (int i = n - 3; i >= 0; --i) {
      x.row(i) = (x.row(i) - u(i, 1) * x.row(i + 1) - u(i, 2) * x.row(i + 2)) /
                 u(i, 0);
    }
  }

  // Solves T * X = RHS with the Thomas algorithm, i.e. Gaussian elimination
  // without pivoting. If a zero pivot is encountered, a warning is logged
  // and the output is filled with NaNs.
  void SolveWithThomasAlgorithm(OpKernelContext* context,
                                const MatrixMapRow& superdiag,
                                const MatrixMapRow& diag,
                                const MatrixMapRow& subdiag,
                                const ConstMatrixMap& rhs, MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The superdiagonal of the U matrix in the LU decomposition of the input
    // matrix (in the Thomas algorithm, the U matrix has ones on the diagonal
    // and one superdiagonal).
    Eigen::Matrix<Scalar, Eigen::Dynamic, 1> u(n);

    if (diag(0) == zero) {
      LOG(WARNING) << kThomasFailedMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }

    // Forward sweep: u(i) = superdiag(i) / (diag(i) - subdiag(i) * u(i - 1)),
    // with the right-hand sides updated using the same denominator.
    u(0) = superdiag(0) / diag(0);
    x.row(0) = rhs.row(0) / diag(0);
    for (int i = 1; i < n; ++i) {
      auto denom = diag(i) - subdiag(i) * u(i - 1);
      if (denom == zero) {
        LOG(WARNING) << kThomasFailedMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
        return;
      }
      u(i) = superdiag(i) / denom;
      x.row(i) = (rhs.row(i) - subdiag(i) * x.row(i - 1)) / denom;
    }
    // Backward sweep.
    for (int i = n - 2; i >= 0; --i) {
      x.row(i) -= u(i) * x.row(i + 1);
    }
  }

  bool pivoting_;
  bool perturb_singular_;
};

REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<double>),
                       double);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex64>),
                       complex64);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex128>),
                       complex128);
}  // namespace tensorflow