# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Matrix functions: iterative methods for computing matrix powers M^p."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops


def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
  """Iterative method to get matrix square root.

  Stable iterations for the matrix square root, Nicholas J. Higham

  Page 231, Eq 2.6b
  http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf

  The coupled iteration is run on mat_a / ||mat_a||_F (after adding the
  ridge), and the result is rescaled by sqrt(||mat_a||_F) at the end.

  After every step, the relative residual ||mat_a - Y^2||_F / ||mat_a||_F of
  the rescaled iterate Y is computed. The loop stops once that error no
  longer decreases, or after iter_count steps. It then returns the iterate
  preceding the last one computed; this is the initial guess if no step ran.

  Args:
    mat_a: the symmetric PSD matrix whose matrix square root is to be computed.
    mat_a_size: size of mat_a.
    iter_count: Maximum number of iterations.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_a^0.5
  """

  def _relative_error(mat_y):
    """Relative Frobenius residual of mat_y (rescaled) as a square root."""
    mat_sqrt_a = mat_y * math_ops.sqrt(norm)
    residual = mat_a - math_ops.matmul(mat_sqrt_a, mat_sqrt_a)
    return math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm

  def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
                      unused_old_mat_z, err, old_err):
    # This method requires that we check for divergence every step.
    return math_ops.logical_and(i < iter_count, err < old_err)

  def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err,
                 unused_old_err):
    current_iterate = 0.5 * (3.0 * identity - math_ops.matmul(mat_z, mat_y))
    current_mat_y = math_ops.matmul(mat_y, current_iterate)
    current_mat_z = math_ops.matmul(current_iterate, mat_z)
    # Compute the error in approximation.
    current_err = _relative_error(current_mat_y)
    return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err

  identity = linalg_ops.eye(math_ops.cast(mat_a_size, dtypes.int32))
  mat_a = mat_a + ridge_epsilon * identity
  norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a))
  mat_init_y = mat_a / norm
  mat_init_z = identity
  # Seed the divergence check with the *relative* error of the starting
  # iterate, i.e. the same metric the loop body produces. Seeding with the
  # absolute norm mixed scales: for matrices with a small norm, the first
  # relative error could exceed it and the loop exited after a single step.
  init_err = _relative_error(mat_init_y)

  _, _, prev_mat_y, _, _, _, _ = control_flow_ops.while_loop(
      _iter_condition, _iter_body, [
          0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err,
          init_err + 1.0
      ])
  return prev_mat_y * math_ops.sqrt(norm)


def _inverse_root_order(alpha):
  """Returns the positive integer p with alpha == -1/p.

  -1.0 / alpha is rounded to the nearest integer, so small floating-point
  representation error in alpha does not reject a valid exponent.

  Args:
    alpha: Python number, expected to be -1/p for a positive integer p.

  Returns:
    p as a Python int.

  Raises:
    ValueError: if alpha is not (numerically) of the form -1/p.
  """
  if not alpha < 0:
    raise ValueError('alpha must be -1/p for a positive integer p, got %r' %
                     (alpha,))
  order = -1.0 / alpha
  rounded = int(round(order))
  if rounded < 1 or abs(order - rounded) > 1e-6 * rounded:
    raise ValueError('alpha must be -1/p for a positive integer p, got %r' %
                     (alpha,))
  return rounded


def _matrix_power(mat_m, p):
  """Computes mat_m^p, for p a positive Python integer, by repeated squaring.

  Power p is known at graph construction time, so the squaring loop is
  unrolled in Python and needs no graph-level loop or cond.

  Args:
    mat_m: a square matrix.
    p: a positive integer.

  Returns:
    mat_m^p
  """
  power = None
  while True:
    if p % 2 == 1:
      power = mat_m if power is None else math_ops.matmul(mat_m, power)
    p //= 2
    if p == 0:
      return power
    # Only square when another bit remains, so no unused matmul is emitted.
    mat_m = math_ops.matmul(mat_m, mat_m)


def matrix_inverse_pth_root(mat_g,
                            mat_g_size,
                            alpha,
                            iter_count=100,
                            epsilon=1e-6,
                            ridge_epsilon=1e-6):
  """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.

  We use an iterative Schur-Newton method from equation 3.2 on page 9 of:

  A Schur-Newton Method for the Matrix p-th Root and its Inverse
  by Chun-Hua Guo and Nicholas J. Higham
  SIAM Journal on Matrix Analysis and Applications,
  2006, Vol. 28, No. 3 : pp. 788-804
  https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf

  With M_i = (1 - alpha) * I + alpha * M_k, the coupled iteration is
    M_{k+1} = M_i^p M_k,    X_{k+1} = X_k M_i,
  starting from M_0 = z * (mat_g + ridge_epsilon * I) and X_0 = z^(-alpha) I.
  It stops when max |M_k - I| <= epsilon or after iter_count steps, and
  returns X_k.

  Args:
    mat_g: the symmetric PSD matrix whose power is to be computed.
    mat_g_size: size of mat_g.
    alpha: exponent, must be -1/p for p a positive integer.
    iter_count: Maximum number of iterations.
    epsilon: accuracy indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^alpha

  Raises:
    ValueError: if mat_g_size != 1 and alpha is not of the form -1/p.
  """
  if mat_g_size == 1:
    # A 1x1 matrix power is just the elementwise power.
    return math_ops.pow(mat_g + ridge_epsilon, alpha)

  root_order = _inverse_root_order(alpha)
  identity = linalg_ops.eye(math_ops.cast(mat_g_size, dtypes.int32))

  def _iter_condition(i, mat_m, _):
    return math_ops.logical_and(
        i < iter_count,
        math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)

  def _iter_body(i, mat_m, mat_x):
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    return (i + 1, math_ops.matmul(_matrix_power(mat_m_i, root_order), mat_m),
            math_ops.matmul(mat_x, mat_m_i))

  damped_mat_g = mat_g + ridge_epsilon * identity
  z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
  # The best value for z is
  # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
  #                 (c_max^{1-alpha} - c_min^{1-alpha})
  # where c_max and c_min are the largest and smallest singular values of
  # damped_mat_g.
  # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
  # The line above can be replaced by the one below, but that choice is less
  # accurate and hence needs more iterations to converge:
  # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
  # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
  # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
  # extra iterations.
  _, _, mat_h = control_flow_ops.while_loop(
      _iter_condition, _iter_body,
      [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
  return mat_h