• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
11 #define EIGEN_MATRIX_SQUARE_ROOT
12 
13 namespace Eigen {
14 
/** \ingroup MatrixFunctions_Module
  * \brief Class for computing matrix square roots of upper quasi-triangular matrices.
  * \tparam  MatrixType  type of the argument of the matrix square root,
  *                      expected to be an instantiation of the Matrix class template.
  *
  * The argument is read from the upper Hessenberg part of the matrix given to
  * the constructor and is interpreted as an upper quasi-triangular matrix:
  * block upper triangular, with 1-by-1 and 2-by-2 blocks on the diagonal.
  * A non-zero subdiagonal entry marks the start of a 2-by-2 block.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootTriangular
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper quasi-triangular matrix whose square root
      *                is to be computed.
      *
      * Only a reference to \p A is kept (no copy is made), so \p A must
      * stay alive and unmodified until compute() has been called.
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * \p result is resized to the size of \p A. Afterwards only its upper
      * Hessenberg part is written; the remaining entries are not set by
      * this class. See MatrixBase::sqrt() for details on how this
      * computation is implemented.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Fills the 1-by-1 and 2-by-2 diagonal blocks of sqrtT.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Fills every block above the diagonal, given the diagonal blocks.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Square root of the 2-by-2 diagonal block starting at (i,i).
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
    // Off-diagonal block (i,j), one helper per combination of row/column block sizes.
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);

    // Solves the 2-by-2 Sylvester equation A X + X B = C.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    // Reference to the argument passed to the constructor (not a copy).
    const MatrixType& m_A;
};
77 
78 template <typename MatrixType>
79 template <typename ResultType>
compute(ResultType & result)80 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
81 {
82   result.resize(m_A.rows(), m_A.cols());
83   computeDiagonalPartOfSqrt(result, m_A);
84   computeOffDiagonalPartOfSqrt(result, m_A);
85 }
86 
87 // pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
88 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
89 template <typename MatrixType>
computeDiagonalPartOfSqrt(MatrixType & sqrtT,const MatrixType & T)90 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
91 									  const MatrixType& T)
92 {
93   using std::sqrt;
94   const Index size = m_A.rows();
95   for (Index i = 0; i < size; i++) {
96     if (i == size - 1 || T.coeff(i+1, i) == 0) {
97       eigen_assert(T(i,i) >= 0);
98       sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
99     }
100     else {
101       compute2x2diagonalBlock(sqrtT, T, i);
102       ++i;
103     }
104   }
105 }
106 
107 // pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
108 // post: sqrtT is the square root of T.
109 template <typename MatrixType>
computeOffDiagonalPartOfSqrt(MatrixType & sqrtT,const MatrixType & T)110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
111 									     const MatrixType& T)
112 {
113   const Index size = m_A.rows();
114   for (Index j = 1; j < size; j++) {
115       if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
116 	continue;
117     for (Index i = j-1; i >= 0; i--) {
118       if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
119 	continue;
120       bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
121       bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
122       if (iBlockIs2x2 && jBlockIs2x2)
123 	compute2x2offDiagonalBlock(sqrtT, T, i, j);
124       else if (iBlockIs2x2 && !jBlockIs2x2)
125 	compute2x1offDiagonalBlock(sqrtT, T, i, j);
126       else if (!iBlockIs2x2 && jBlockIs2x2)
127 	compute1x2offDiagonalBlock(sqrtT, T, i, j);
128       else if (!iBlockIs2x2 && !jBlockIs2x2)
129 	compute1x1offDiagonalBlock(sqrtT, T, i, j);
130     }
131   }
132 }
133 
134 // pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
135 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
136 template <typename MatrixType>
137 void MatrixSquareRootQuasiTriangular<MatrixType>
compute2x2diagonalBlock(MatrixType & sqrtT,const MatrixType & T,typename MatrixType::Index i)138      ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
139 {
140   // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
141   //       in EigenSolver. If we expose it, we could call it directly from here.
142   Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
143   EigenSolver<Matrix<Scalar,2,2> > es(block);
144   sqrtT.template block<2,2>(i,i)
145     = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
146 }
147 
148 // pre:  block structure of T is such that (i,j) is a 1x1 block,
149 //       all blocks of sqrtT to left of and below (i,j) are correct
150 // post: sqrtT(i,j) has the correct value
151 template <typename MatrixType>
152 void MatrixSquareRootQuasiTriangular<MatrixType>
compute1x1offDiagonalBlock(MatrixType & sqrtT,const MatrixType & T,typename MatrixType::Index i,typename MatrixType::Index j)153      ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
154 				  typename MatrixType::Index i, typename MatrixType::Index j)
155 {
156   Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
157   sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
158 }
159 
160 // similar to compute1x1offDiagonalBlock()
161 template <typename MatrixType>
162 void MatrixSquareRootQuasiTriangular<MatrixType>
compute1x2offDiagonalBlock(MatrixType & sqrtT,const MatrixType & T,typename MatrixType::Index i,typename MatrixType::Index j)163      ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
164 				  typename MatrixType::Index i, typename MatrixType::Index j)
165 {
166   Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
167   if (j-i > 1)
168     rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
169   Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
170   A += sqrtT.template block<2,2>(j,j).transpose();
171   sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
172 }
173 
174 // similar to compute1x1offDiagonalBlock()
175 template <typename MatrixType>
176 void MatrixSquareRootQuasiTriangular<MatrixType>
compute2x1offDiagonalBlock(MatrixType & sqrtT,const MatrixType & T,typename MatrixType::Index i,typename MatrixType::Index j)177      ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
178 				  typename MatrixType::Index i, typename MatrixType::Index j)
179 {
180   Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
181   if (j-i > 2)
182     rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
183   Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
184   A += sqrtT.template block<2,2>(i,i);
185   sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
186 }
187 
188 // similar to compute1x1offDiagonalBlock()
189 template <typename MatrixType>
190 void MatrixSquareRootQuasiTriangular<MatrixType>
compute2x2offDiagonalBlock(MatrixType & sqrtT,const MatrixType & T,typename MatrixType::Index i,typename MatrixType::Index j)191      ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
192 				  typename MatrixType::Index i, typename MatrixType::Index j)
193 {
194   Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
195   Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
196   Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
197   if (j-i > 2)
198     C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
199   Matrix<Scalar,2,2> X;
200   solveAuxiliaryEquation(X, A, B, C);
201   sqrtT.template block<2,2>(i,j) = X;
202 }
203 
204 // solves the equation A X + X B = C where all matrices are 2-by-2
205 template <typename MatrixType>
206 template <typename SmallMatrixType>
207 void MatrixSquareRootQuasiTriangular<MatrixType>
solveAuxiliaryEquation(SmallMatrixType & X,const SmallMatrixType & A,const SmallMatrixType & B,const SmallMatrixType & C)208      ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
209 			      const SmallMatrixType& B, const SmallMatrixType& C)
210 {
211   EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
212 		      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
213 
214   Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
215   coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
216   coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
217   coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
218   coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
219   coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
220   coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
221   coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
222   coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
223   coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
224   coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
225   coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
226   coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
227 
228   Matrix<Scalar,4,1> rhs;
229   rhs.coeffRef(0) = C.coeff(0,0);
230   rhs.coeffRef(1) = C.coeff(0,1);
231   rhs.coeffRef(2) = C.coeff(1,0);
232   rhs.coeffRef(3) = C.coeff(1,1);
233 
234   Matrix<Scalar,4,1> result;
235   result = coeffMatrix.fullPivLu().solve(rhs);
236 
237   X.coeffRef(0,0) = result.coeff(0);
238   X.coeffRef(0,1) = result.coeff(1);
239   X.coeffRef(1,0) = result.coeff(2);
240   X.coeffRef(1,1) = result.coeff(3);
241 }
242 
243 
/** \ingroup MatrixFunctions_Module
  * \brief Class for computing matrix square roots of upper triangular matrices.
  * \tparam  MatrixType  type of the argument of the matrix square root,
  *                      expected to be an instantiation of the Matrix class template.
  *
  * Only the upper triangular part (diagonal included) of the matrix passed
  * to the constructor is read; entries below the diagonal are ignored. The
  * square root of that triangular matrix is computed.
  *
  * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:

    /** \brief Constructor.
      *
      * \param[in]  A  upper triangular matrix whose square root is to be computed.
      *
      * Only a reference to \p A is kept, so \p A must stay alive and
      * unmodified until compute() has been called.
      */
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Compute the matrix square root
      *
      * \param[out] result  square root of \p A, as specified in the constructor.
      *
      * \p result is resized to the size of \p A. Only its upper triangular
      * part (including the diagonal) is then written; entries below the
      * diagonal are not set by this class. See MatrixBase::sqrt() for
      * details on how this computation is implemented.
      */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Reference to the argument passed to the constructor (not a copy).
    const MatrixType& m_A;
};
279 
280 template <typename MatrixType>
281 template <typename ResultType>
compute(ResultType & result)282 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
283 {
284   using std::sqrt;
285 
286   // Compute square root of m_A and store it in upper triangular part of result
287   // This uses that the square root of triangular matrices can be computed directly.
288   result.resize(m_A.rows(), m_A.cols());
289   typedef typename MatrixType::Index Index;
290   for (Index i = 0; i < m_A.rows(); i++) {
291     result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
292   }
293   for (Index j = 1; j < m_A.cols(); j++) {
294     for (Index i = j-1; i >= 0; i--) {
295       typedef typename MatrixType::Scalar Scalar;
296       // if i = j-1, then segment has length 0 so tmp = 0
297       Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
298       // denominator may be zero if original matrix is singular
299       result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
300     }
301   }
302 }
303 
304 
305 /** \ingroup MatrixFunctions_Module
306   * \brief Class for computing matrix square roots of general matrices.
307   * \tparam  MatrixType  type of the argument of the matrix square root,
308   *                      expected to be an instantiation of the Matrix class template.
309   *
310   * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt()
311   */
312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
313 class MatrixSquareRoot
314 {
315   public:
316 
317     /** \brief Constructor.
318       *
319       * \param[in]  A  matrix whose square root is to be computed.
320       *
321       * The class stores a reference to \p A, so it should not be
322       * changed (or destroyed) before compute() is called.
323       */
324     MatrixSquareRoot(const MatrixType& A);
325 
326     /** \brief Compute the matrix square root
327       *
328       * \param[out] result  square root of \p A, as specified in the constructor.
329       *
330       * See MatrixBase::sqrt() for details on how this computation is
331       * implemented.
332       */
333     template <typename ResultType> void compute(ResultType &result);
334 };
335 
336 
337 // ********** Partial specialization for real matrices **********
338 
339 template <typename MatrixType>
340 class MatrixSquareRoot<MatrixType, 0>
341 {
342   public:
343 
MatrixSquareRoot(const MatrixType & A)344     MatrixSquareRoot(const MatrixType& A)
345       : m_A(A)
346     {
347       eigen_assert(A.rows() == A.cols());
348     }
349 
compute(ResultType & result)350     template <typename ResultType> void compute(ResultType &result)
351     {
352       // Compute Schur decomposition of m_A
353       const RealSchur<MatrixType> schurOfA(m_A);
354       const MatrixType& T = schurOfA.matrixT();
355       const MatrixType& U = schurOfA.matrixU();
356 
357       // Compute square root of T
358       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
359       MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
360 
361       // Compute square root of m_A
362       result = U * sqrtT * U.adjoint();
363     }
364 
365   private:
366     const MatrixType& m_A;
367 };
368 
369 
370 // ********** Partial specialization for complex matrices **********
371 
372 template <typename MatrixType>
373 class MatrixSquareRoot<MatrixType, 1>
374 {
375   public:
376 
MatrixSquareRoot(const MatrixType & A)377     MatrixSquareRoot(const MatrixType& A)
378       : m_A(A)
379     {
380       eigen_assert(A.rows() == A.cols());
381     }
382 
compute(ResultType & result)383     template <typename ResultType> void compute(ResultType &result)
384     {
385       // Compute Schur decomposition of m_A
386       const ComplexSchur<MatrixType> schurOfA(m_A);
387       const MatrixType& T = schurOfA.matrixT();
388       const MatrixType& U = schurOfA.matrixU();
389 
390       // Compute square root of T
391       MatrixType sqrtT;
392       MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
393 
394       // Compute square root of m_A
395       result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
396     }
397 
398   private:
399     const MatrixType& m_A;
400 };
401 
402 
403 /** \ingroup MatrixFunctions_Module
404   *
405   * \brief Proxy for the matrix square root of some matrix (expression).
406   *
407   * \tparam Derived  Type of the argument to the matrix square root.
408   *
409   * This class holds the argument to the matrix square root until it
410   * is assigned or evaluated for some other reason (so the argument
411   * should not be changed in the meantime). It is the return type of
412   * MatrixBase::sqrt() and most of the time this is the only way it is
413   * used.
414   */
415 template<typename Derived> class MatrixSquareRootReturnValue
416 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
417 {
418     typedef typename Derived::Index Index;
419   public:
420     /** \brief Constructor.
421       *
422       * \param[in]  src  %Matrix (expression) forming the argument of the
423       * matrix square root.
424       */
MatrixSquareRootReturnValue(const Derived & src)425     MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
426 
427     /** \brief Compute the matrix square root.
428       *
429       * \param[out]  result  the matrix square root of \p src in the
430       * constructor.
431       */
432     template <typename ResultType>
evalTo(ResultType & result)433     inline void evalTo(ResultType& result) const
434     {
435       const typename Derived::PlainObject srcEvaluated = m_src.eval();
436       MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
437       me.compute(result);
438     }
439 
rows()440     Index rows() const { return m_src.rows(); }
cols()441     Index cols() const { return m_src.cols(); }
442 
443   protected:
444     const Derived& m_src;
445   private:
446     MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
447 };
448 
449 namespace internal {
450 template<typename Derived>
451 struct traits<MatrixSquareRootReturnValue<Derived> >
452 {
453   typedef typename Derived::PlainObject ReturnType;
454 };
455 }
456 
457 template <typename Derived>
458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
459 {
460   eigen_assert(rows() == cols());
461   return MatrixSquareRootReturnValue<Derived>(derived());
462 }
463 
464 } // end namespace Eigen
465 
466 #endif // EIGEN_MATRIX_SQUARE_ROOT
467