# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perturb a `LinearOperator` with a rank `K` update."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowRankUpdate",
]


@tf_export("linalg.LinearOperatorLowRankUpdate")
class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
  """Perturb a `LinearOperator` with a rank `K` update.

  This operator acts like a [batch] matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  `LinearOperatorLowRankUpdate` represents `A = L + U D V^H`, where

  ```
  L, is a LinearOperator representing [batch] M x N matrices
  U, is a [batch] M x K matrix.  Typically K << M.
  D, is a [batch] K x K matrix.
  V, is a [batch] N x K matrix.  Typically K << N.
  V^H is the Hermitian transpose (adjoint) of V.
  ```

  If `M = N`, determinants and solves are done using the matrix determinant
  lemma and Woodbury identities, and thus require `L` and `D` to be
  non-singular.

  Solves and determinants will be attempted unless the `is_non_singular`
  property of `L` and `D` is `False`.

  In the event that `L` and `D` are positive-definite, and `U = V`, solves
  and determinants can be done using a Cholesky factorization.

  ```python
  # Create a 3 x 3 diagonal linear operator.
  diag_operator = LinearOperatorDiag(
      diag=[1., 2., 3.], is_non_singular=True, is_self_adjoint=True,
      is_positive_definite=True)

  # Perturb with a rank 2 perturbation
  operator = LinearOperatorLowRankUpdate(
      base_operator=diag_operator,
      u=[[1., 2.], [-1., 3.], [0., 0.]],
      diag_update=[11., 12.],
      v=[[1., 2.], [-1., 3.], [10., 10.]])

  operator.shape
  ==> [3, 3]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```
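
  Since the operator above is square and built from non-singular pieces,
  solves are available as well.  A minimal sketch, reusing `operator` from
  above (internally this applies the Woodbury identity; see "Performance"
  below):

  ```python
  rhs = ... Shape [3, 2] Tensor
  operator.solve(rhs)
  ==> Shape [3, 2] Tensor such that operator.matmul(result) ~= rhs
  ```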

  ### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape = [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  ### Performance

  Suppose `operator` is a `LinearOperatorLowRankUpdate` of shape `[M, N]`,
  made from a rank `K` update of `base_operator` which performs `.matmul(x)`
  on `x` having `x.shape = [N, R]` with `O(L_matmul*N*R)` complexity (and
  similarly for `solve` and `determinant`).  Then, if `x.shape = [N, R]`,

  * `operator.matmul(x)` is `O(L_matmul*N*R + K*N*R)`

  and if `M = N`,

  * `operator.solve(x)` is `O(L_matmul*N*R + N*K*R + K^2*R + K^3)`
  * `operator.determinant()` is `O(L_determinant + L_solve*N*K + K^2*N + K^3)`

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by
  `B1*...*Bb`.
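
  As a rough worked example (ignoring constant factors): for a diagonal
  `base_operator`, `L_matmul = 1`, so with `N = 10**6`, `K = 10`, and
  `R = 1`, `operator.matmul(x)` costs on the order of `10**7` operations,
  versus the `N^2 * R = 10**12` of a dense `N x N` matmul.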

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular`, `self_adjoint`, `positive_definite`,
  `diag_update_positive` and `square`.  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation
    either way.
  """

  def __init__(self,
               base_operator,
               u,
               diag_update=None,
               v=None,
               is_diag_update_positive=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowRankUpdate"):
    """Initialize a `LinearOperatorLowRankUpdate`.

    This creates a `LinearOperator` of the form `A = L + U D V^H`, with
    `L` a `LinearOperator`, `U, V` both [batch] matrices, and `D` a [batch]
    diagonal matrix.

    If `L` is non-singular, solves and determinants are available.
    Solves/determinants both involve a solve/determinant of a `K x K` system.
    In the event that `L` and `D` are self-adjoint positive-definite, and
    `U = V`, this can be done using a Cholesky factorization.  The user
    should set the `is_X` matrix property hints, which will trigger the
    appropriate code path.

    Args:
      base_operator:  Shape `[B1,...,Bb, M, N]`.
      u:  Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as
        `base_operator`.  This is `U` above.
      diag_update:  Optional shape `[B1,...,Bb, K]` `Tensor` with same
        `dtype` as `base_operator`.  This is the diagonal of `D` above.
        Defaults to `D` being the identity operator.
      v:  Optional `Tensor` of same `dtype` as `u` and shape
        `[B1,...,Bb, N, K]`.  Defaults to `v = u`, in which case the
        perturbation is symmetric.  If `M != N`, then `v` must be set since
        the perturbation is not square.
      is_diag_update_positive:  Python `bool`.  If `True`, expect
        `diag_update > 0`.
      is_non_singular:  Expect that this operator is non-singular.
        Default is `None`, unless `is_positive_definite` is auto-set to be
        `True` (see below).
      is_self_adjoint:  Expect that this operator is equal to its Hermitian
        transpose.  Default is `None`, unless `base_operator` is self-adjoint
        and `v = None` (meaning `u = v`), in which case this defaults to
        `True`.
      is_positive_definite:  Expect that this operator is positive definite.
        Default is `None`, unless `base_operator` is positive-definite,
        `v = None` (meaning `u = v`), and `is_diag_update_positive` is
        `True`, in which case this defaults to `True`.
        Note that we say an operator is positive definite when the quadratic
        form `x^H A x` has positive real part for all nonzero `x`.
      is_square:  Expect that this operator acts like square [batch]
        matrices.
      name:  A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_X` flags are set in an inconsistent way.
    """
    dtype = base_operator.dtype

    if diag_update is not None:
      if is_diag_update_positive and dtype.is_complex:
        logging.warn("Note: setting is_diag_update_positive with a complex "
                     "dtype means that diagonal is real and positive.")

    if diag_update is None:
      if is_diag_update_positive is False:
        raise ValueError(
            "Default diagonal is the identity, which is positive.  However, "
            "user set 'is_diag_update_positive' to False.")
      is_diag_update_positive = True

    # In this case, we can use a Cholesky decomposition to help us solve/det.
    self._use_cholesky = (
        base_operator.is_positive_definite and base_operator.is_self_adjoint
        and is_diag_update_positive
        and v is None)

    # Possibly auto-set some characteristic flags from None to True.
    # If the flags were set (by the user) incorrectly to False, then raise.
    if base_operator.is_self_adjoint and v is None and not dtype.is_complex:
      if is_self_adjoint is False:
        raise ValueError(
            "A = L + UDU^H, with L self-adjoint and D real diagonal.  Since "
            "UDU^H is self-adjoint, this must be a self-adjoint operator.")
      is_self_adjoint = True

    # The condition for using a Cholesky factorization is sufficient for SPD,
    # and no weaker choice of these hints leads to SPD.  Therefore,
    # the following line reads "if hints indicate SPD..."
    if self._use_cholesky:
      if (
          is_positive_definite is False
          or is_self_adjoint is False
          or is_non_singular is False):
        raise ValueError(
            "Arguments imply this is a self-adjoint, positive-definite "
            "operator.")
      is_positive_definite = True
      is_self_adjoint = True

    values = base_operator.graph_parents + [u, diag_update, v]
    with ops.name_scope(name, values=values):

      # Create U and V.
      self._u = ops.convert_to_tensor(u, name="u")
      if v is None:
        self._v = self._u
      else:
        self._v = ops.convert_to_tensor(v, name="v")

      if diag_update is None:
        self._diag_update = None
      else:
        self._diag_update = ops.convert_to_tensor(
            diag_update, name="diag_update")

      # Create base_operator L.
      self._base_operator = base_operator
      graph_parents = base_operator.graph_parents + [
          self.u, self._diag_update, self.v]
      graph_parents = [p for p in graph_parents if p is not None]

      super(LinearOperatorLowRankUpdate, self).__init__(
          dtype=self._base_operator.dtype,
          graph_parents=graph_parents,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

      # Create the diagonal operator D.
      self._set_diag_operators(diag_update, is_diag_update_positive)
      self._is_diag_update_positive = is_diag_update_positive

      self._check_shapes()

      # Pre-compute the so-called "capacitance" matrix
      #   C := D^{-1} + V^H L^{-1} U
      self._capacitance = self._make_capacitance()
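      # Under the SPD hints (L self-adjoint positive-definite, D > 0, and
      # u = v), the capacitance C is Hermitian positive-definite, so its
      # Cholesky factor can be formed once and reused by both the solve and
      # the log-determinant below.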
      if self._use_cholesky:
        self._chol_capacitance = linalg_ops.cholesky(self._capacitance)

  def _check_shapes(self):
    """Static check that shapes are compatible."""
    # Broadcast shape also checks that u and v are compatible.
    uv_shape = array_ops.broadcast_static_shape(
        self.u.get_shape(), self.v.get_shape())

    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    tensor_shape.Dimension(
        self.base_operator.domain_dimension).assert_is_compatible_with(
            uv_shape[-2])

    if self._diag_update is not None:
      tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
          self._diag_update.get_shape()[-1])
      array_ops.broadcast_static_shape(
          batch_shape, self._diag_update.get_shape()[:-1])

  def _set_diag_operators(self, diag_update, is_diag_update_positive):
    """Set attributes self._diag_update and self._diag_operator."""
    if diag_update is not None:
      self._diag_operator = linear_operator_diag.LinearOperatorDiag(
          self._diag_update, is_positive_definite=is_diag_update_positive)
      self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
          1. / self._diag_update, is_positive_definite=is_diag_update_positive)
    else:
      # Default to D = I_K, which is its own inverse.
      if tensor_shape.dimension_value(self.u.shape[-1]) is not None:
        r = tensor_shape.dimension_value(self.u.shape[-1])
      else:
        r = array_ops.shape(self.u)[-1]
      self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
          num_rows=r, dtype=self.dtype)
      self._diag_inv_operator = self._diag_operator

  @property
  def u(self):
    """If this operator is `A = L + U D V^H`, this is the `U`."""
    return self._u

  @property
  def v(self):
    """If this operator is `A = L + U D V^H`, this is the `V`."""
    return self._v

  @property
  def is_diag_update_positive(self):
    """If this operator is `A = L + U D V^H`, this hints `D > 0` elementwise."""
    return self._is_diag_update_positive

  @property
  def diag_update(self):
    """If this operator is `A = L + U D V^H`, this is the diagonal of `D`."""
    return self._diag_update

  @property
  def diag_operator(self):
    """If this operator is `A = L + U D V^H`, this is `D`."""
    return self._diag_operator

  @property
  def base_operator(self):
    """If this operator is `A = L + U D V^H`, this is the `L`."""
    return self._base_operator

  def _shape(self):
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape,
        self.u.get_shape()[:-2])
    return batch_shape.concatenate(self.base_operator.shape[-2:])

  def _shape_tensor(self):
    batch_shape = array_ops.broadcast_dynamic_shape(
        self.base_operator.batch_shape_tensor(),
        array_ops.shape(self.u)[:-2])
    return array_ops.concat(
        [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    u = self.u
    v = self.v
    l = self.base_operator
    d = self.diag_operator

    leading_term = l.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
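
    # If adjoint, compute A^H x = L^H x + V D^H (U^H x); otherwise compute
    # A x = L x + U D (V^H x).  Grouping the products this way keeps the
    # update term at O(K*N*R) cost.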
    if adjoint:
      uh_x = linear_operator_util.matmul_with_broadcast(
          u, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_uh_x = d.matmul(uh_x, adjoint=adjoint)
      v_d_uh_x = linear_operator_util.matmul_with_broadcast(v, d_uh_x)
      return leading_term + v_d_uh_x
    else:
      vh_x = linear_operator_util.matmul_with_broadcast(
          v, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_vh_x = d.matmul(vh_x, adjoint=adjoint)
      u_d_vh_x = linear_operator_util.matmul_with_broadcast(u, d_vh_x)
      return leading_term + u_d_vh_x

  def _determinant(self):
    if self.is_positive_definite:
      return math_ops.exp(self.log_abs_determinant())
    # The matrix determinant lemma gives
    # https://en.wikipedia.org/wiki/Matrix_determinant_lemma
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    # where C is sometimes known as the capacitance matrix,
    #   C := D^{-1} + V^H L^{-1} U
    det_c = linalg_ops.matrix_determinant(self._capacitance)
    det_d = self.diag_operator.determinant()
    det_l = self.base_operator.determinant()
    return det_c * det_d * det_l

  def _log_abs_determinant(self):
    # Recall
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    log_abs_det_d = self.diag_operator.log_abs_determinant()
    log_abs_det_l = self.base_operator.log_abs_determinant()

    if self._use_cholesky:
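      # With C = R R^H the Cholesky factorization (R lower triangular with
      # real positive diagonal), det(C) = prod(diag(R))**2, so
      # log|det(C)| = 2 * sum(log(diag(R))).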
      chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance)
      log_abs_det_c = 2 * math_ops.reduce_sum(
          math_ops.log(chol_cap_diag), axis=[-1])
    else:
      det_c = linalg_ops.matrix_determinant(self._capacitance)
      log_abs_det_c = math_ops.log(math_ops.abs(det_c))
      if self.dtype.is_complex:
        log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)

    return log_abs_det_c + log_abs_det_d + log_abs_det_l

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    if self.base_operator.is_non_singular is False:
      raise ValueError(
          "Solve not implemented unless this is a perturbation of a "
          "non-singular LinearOperator.")
    # The Woodbury formula gives
    # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
    #   (L + UDV^H)^{-1}
    #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U.
    # Note also that, with ^{-H} being the inverse of the adjoint,
    #   (L + UDV^H)^{-H}
    #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
    l = self.base_operator
    if adjoint:
      # For the adjoint system, U and V swap roles (see the identity above).
      v = self.u
      u = self.v
    else:
      v = self.v
      u = self.u

    # L^{-1} rhs
    linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    # V^H L^{-1} rhs
    vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
        v, linv_rhs, adjoint_a=True)
    # C^{-1} V^H L^{-1} rhs
    if self._use_cholesky:
      capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast(
          self._chol_capacitance, vh_linv_rhs)
    else:
      capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
          self._capacitance, vh_linv_rhs, adjoint=adjoint)
    # U C^{-1} V^H L^{-1} rhs
    u_capinv_vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
        u, capinv_vh_linv_rhs)
    # L^{-1} U C^{-1} V^H L^{-1} rhs
    linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs, adjoint=adjoint)

    # (L^{-1} - L^{-1} U C^{-1} V^H L^{-1}) rhs
    return linv_rhs - linv_u_capinv_vh_linv_rhs

  def _make_capacitance(self):
    # C := D^{-1} + V^H L^{-1} U
    # which is sometimes known as the "capacitance" matrix.

    # L^{-1} U
    linv_u = self.base_operator.solve(self.u)
    # V^H L^{-1} U
    vh_linv_u = linear_operator_util.matmul_with_broadcast(
        self.v, linv_u, adjoint_a=True)

    # D^{-1} + V^H L^{-1} U
    capacitance = self._diag_inv_operator.add_to_tensor(vh_linv_u)
    return capacitance