# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perturb a `LinearOperator` with a rank `K` update."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowRankUpdate",
]


@tf_export("linalg.LinearOperatorLowRankUpdate")
@linear_operator.make_composite_tensor
class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
  """Perturb a `LinearOperator` with a rank `K` update.

  This operator acts like a [batch] matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  `LinearOperatorLowRankUpdate` represents `A = L + U D V^H`, where

  ```
  L, is a LinearOperator representing [batch] M x N matrices
  U, is a [batch] M x K matrix. Typically K << M.
  D, is a [batch] K x K matrix.
  V, is a [batch] N x K matrix. Typically K << N.
  V^H is the Hermitian transpose (adjoint) of V.
  ```

  If `M = N`, determinants and solves are done using the matrix determinant
  lemma and Woodbury identities, and thus require `L` and `D` to be
  non-singular.

  Solves and determinants will be attempted unless the `is_non_singular`
  property of `L` and `D` is `False`.

  In the event that `L` and `D` are positive-definite, and `U = V`, solves and
  determinants can be done using a Cholesky factorization.

  ```python
  # Create a 3 x 3 diagonal linear operator.
  diag_operator = LinearOperatorDiag(
      diag=[1., 2., 3.], is_non_singular=True, is_self_adjoint=True,
      is_positive_definite=True)

  # Perturb with a rank 2 perturbation.
  operator = LinearOperatorLowRankUpdate(
      base_operator=diag_operator,
      u=[[1., 2.], [-1., 3.], [0., 0.]],
      diag_update=[11., 12.],
      v=[[1., 2.], [-1., 3.], [10., 10.]])

  operator.shape
  ==> [3, 3]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  ### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N], with b >= 0
  x.shape = [B1,...,Bb] + [N, R], with R >= 0.
  ```
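
  If `M != N`, the perturbation is non-square and `v` must be supplied. For
  example (a sketch; the values below are purely illustrative):

  ```python
  # Perturb a [4, 3] base operator with a rank 1 update.
  base_operator = ... LinearOperator representing a [4, 3] matrix
  operator = LinearOperatorLowRankUpdate(
      base_operator=base_operator,
      u=[[1.], [2.], [0.], [0.]],
      v=[[1.], [0.], [-1.]])

  operator.shape
  ==> [4, 3]

  x = ... Shape [3, 5] Tensor
  operator.matmul(x)
  ==> Shape [4, 5] Tensor
  ```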

  ### Performance

  Suppose `operator` is a `LinearOperatorLowRankUpdate` of shape `[M, N]`,
  made from a rank `K` update of `base_operator` which performs `.matmul(x)`
  on `x` having `x.shape = [N, R]` with `O(L_matmul*N*R)` complexity (and
  similarly for `solve` and `determinant`). Then, if `x.shape = [N, R]`,

  * `operator.matmul(x)` is `O(L_matmul*N*R + K*N*R)`

  and if `M = N`,

  * `operator.solve(x)` is `O(L_matmul*N*R + N*K*R + K^2*R + K^3)`
  * `operator.determinant()` is `O(L_determinant + L_solve*N*K + K^2*N + K^3)`

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  ### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular`, `self_adjoint`, `positive_definite`,
  `diag_update_positive` and `square`. These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               base_operator,
               u,
               diag_update=None,
               v=None,
               is_diag_update_positive=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowRankUpdate"):
    """Initialize a `LinearOperatorLowRankUpdate`.

    This creates a `LinearOperator` of the form `A = L + U D V^H`, with
    `L` a `LinearOperator`, `U, V` both [batch] matrices, and `D` a [batch]
    diagonal matrix.

    If `L` is non-singular, solves and determinants are available.
    Solves/determinants both involve a solve/determinant of a `K x K` system.
    In the event that `L` and `D` are self-adjoint positive-definite, and
    `U = V`, this can be done using a Cholesky factorization. The user should
    set the `is_X` matrix property hints, which will trigger the appropriate
    code path.
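
    For example, a sketch of the Cholesky-enabled configuration (the names
    below are placeholders): if `base_operator` was built with
    `is_positive_definite=True` and `is_self_adjoint=True`, then leaving
    `v=None` and hinting a positive diagonal update selects the Cholesky path:

    ```python
    operator = LinearOperatorLowRankUpdate(
        base_operator=base_operator,
        u=u,
        diag_update=diag_update,
        is_diag_update_positive=True)
    ```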

    Args:
      base_operator: Shape `[B1,...,Bb, M, N]`.
      u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
        This is `U` above.
      diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
        as `base_operator`. This is the diagonal of `D` above.
        Defaults to `D` being the identity operator.
      v: Optional `Tensor` of same `dtype` as `u` and shape `[B1,...,Bb, N, K]`.
        Defaults to `v = u`, in which case the perturbation is symmetric.
        If `M != N`, then `v` must be set since the perturbation is not square.
      is_diag_update_positive: Python `bool`.
        If `True`, expect `diag_update > 0`.
      is_non_singular: Expect that this operator is non-singular.
        Default is `None`, unless `is_positive_definite` is auto-set to be
        `True` (see below).
      is_self_adjoint: Expect that this operator is equal to its Hermitian
        transpose. Default is `None`, unless `base_operator` is self-adjoint
        and `v = None` (meaning `u = v`), in which case this defaults to
        `True`.
      is_positive_definite: Expect that this operator is positive definite.
        Default is `None`, unless `base_operator` is positive-definite,
        `v = None` (meaning `u = v`), and `is_diag_update_positive` is `True`,
        in which case this defaults to `True`.
        Note that we say an operator is positive definite when the quadratic
        form `x^H A x` has positive real part for all nonzero `x`.
      is_square: Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError: If `is_X` flags are set in an inconsistent way.
    """
    parameters = dict(
        base_operator=base_operator,
        u=u,
        diag_update=diag_update,
        v=v,
        is_diag_update_positive=is_diag_update_positive,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )
    dtype = base_operator.dtype

    if diag_update is not None:
      if is_diag_update_positive and dtype.is_complex:
        logging.warn("Note: setting is_diag_update_positive with a complex "
                     "dtype means that diagonal is real and positive.")

    if diag_update is None:
      if is_diag_update_positive is False:
        raise ValueError(
            "Default diagonal is the identity, which is positive. However, "
            "user set 'is_diag_update_positive' to False.")
      is_diag_update_positive = True

    # In this case, we can use a Cholesky decomposition to help us solve/det.
    self._use_cholesky = (
        base_operator.is_positive_definite and base_operator.is_self_adjoint
        and is_diag_update_positive
        and v is None)

    # Possibly auto-set some characteristic flags from None to True.
    # If the flags were set (by the user) incorrectly to False, then raise.
    if base_operator.is_self_adjoint and v is None and not dtype.is_complex:
      if is_self_adjoint is False:
        raise ValueError(
            "A = L + UDU^H, with L self-adjoint and D real diagonal. Since"
            " UDU^H is self-adjoint, this must be a self-adjoint operator.")
      is_self_adjoint = True

    # The condition for using a Cholesky is sufficient for SPD, and
    # no weaker choice of these hints leads to SPD. Therefore,
    # the following line reads "if hints indicate SPD..."
    if self._use_cholesky:
      if (
          is_positive_definite is False
          or is_self_adjoint is False
          or is_non_singular is False):
        raise ValueError(
            "Arguments imply this is self-adjoint positive-definite operator.")
      is_positive_definite = True
      is_self_adjoint = True

    with ops.name_scope(name):

      # Create U and V.
      self._u = linear_operator_util.convert_nonref_to_tensor(u, name="u")
      if v is None:
        self._v = self._u
      else:
        self._v = linear_operator_util.convert_nonref_to_tensor(v, name="v")

      if diag_update is None:
        self._diag_update = None
      else:
        self._diag_update = linear_operator_util.convert_nonref_to_tensor(
            diag_update, name="diag_update")

      # Create base_operator L.
      self._base_operator = base_operator

      super(LinearOperatorLowRankUpdate, self).__init__(
          dtype=self._base_operator.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      # Create the diagonal operator D.
      self._set_diag_operators(diag_update, is_diag_update_positive)
      self._is_diag_update_positive = is_diag_update_positive

      self._check_shapes()

  def _check_shapes(self):
    """Static check that shapes are compatible."""
    # Broadcast shape also checks that u and v are compatible.
    uv_shape = array_ops.broadcast_static_shape(
        self.u.shape, self.v.shape)

    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    tensor_shape.Dimension(
        self.base_operator.domain_dimension).assert_is_compatible_with(
            uv_shape[-2])

    if self._diag_update is not None:
      tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
          self._diag_update.shape[-1])
      # This broadcast raises if the batch shapes are not compatible.
      array_ops.broadcast_static_shape(
          batch_shape, self._diag_update.shape[:-1])

  def _set_diag_operators(self, diag_update, is_diag_update_positive):
    """Set attributes self._diag_update and self._diag_operator."""
    if diag_update is not None:
      self._diag_operator = linear_operator_diag.LinearOperatorDiag(
          self._diag_update, is_positive_definite=is_diag_update_positive)
    else:
      if tensor_shape.dimension_value(self.u.shape[-1]) is not None:
        r = tensor_shape.dimension_value(self.u.shape[-1])
      else:
        r = array_ops.shape(self.u)[-1]
      self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
          num_rows=r, dtype=self.dtype)

  @property
  def u(self):
    """If this operator is `A = L + U D V^H`, this is the `U`."""
    return self._u

  @property
  def v(self):
    """If this operator is `A = L + U D V^H`, this is the `V`."""
    return self._v

  @property
  def is_diag_update_positive(self):
    """If this operator is `A = L + U D V^H`, this hints `D > 0` elementwise."""
    return self._is_diag_update_positive

  @property
  def diag_update(self):
    """If this operator is `A = L + U D V^H`, this is the diagonal of `D`."""
    return self._diag_update

  @property
  def diag_operator(self):
    """If this operator is `A = L + U D V^H`, this is `D`."""
    return self._diag_operator

  @property
  def base_operator(self):
    """If this operator is `A = L + U D V^H`, this is the `L`."""
    return self._base_operator

  def _assert_self_adjoint(self):
    # Recall this operator is:
    #   A = L + UDV^H.
    # So in the case u = v and D = None, self-adjointness depends only on L.
    if self.u is self.v and self.diag_update is None:
      return self.base_operator.assert_self_adjoint()
    # In all other cases, sufficient conditions for self-adjointness can be
    # found efficiently. However, those conditions are not necessary
    # conditions.
    return super(LinearOperatorLowRankUpdate, self)._assert_self_adjoint()

  def _shape(self):
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape,
        self.diag_operator.batch_shape)
    batch_shape = array_ops.broadcast_static_shape(
        batch_shape,
        self.u.shape[:-2])
    batch_shape = array_ops.broadcast_static_shape(
        batch_shape,
        self.v.shape[:-2])
    return batch_shape.concatenate(self.base_operator.shape[-2:])

  def _shape_tensor(self):
    batch_shape = array_ops.broadcast_dynamic_shape(
        self.base_operator.batch_shape_tensor(),
        self.diag_operator.batch_shape_tensor())
    batch_shape = array_ops.broadcast_dynamic_shape(
        batch_shape,
        array_ops.shape(self.u)[:-2])
    batch_shape = array_ops.broadcast_dynamic_shape(
        batch_shape,
        array_ops.shape(self.v)[:-2])
    return array_ops.concat(
        [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)

  def _get_uv_as_tensors(self):
    """Get (self.u, self.v) as tensors (in case they were refs)."""
    u = ops.convert_to_tensor_v2_with_dispatch(self.u)
    if self.v is self.u:
      v = u
    else:
      v = ops.convert_to_tensor_v2_with_dispatch(self.v)
    return u, v

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    u, v = self._get_uv_as_tensors()
    l = self.base_operator
    d = self.diag_operator

    leading_term = l.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

    if adjoint:
      # A^H x = L^H x + V D^H U^H x.
      uh_x = math_ops.matmul(u, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_uh_x = d.matmul(uh_x, adjoint=adjoint)
      v_d_uh_x = math_ops.matmul(v, d_uh_x)
      return leading_term + v_d_uh_x
    else:
      # A x = L x + U D V^H x.
      vh_x = math_ops.matmul(v, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_vh_x = d.matmul(vh_x, adjoint=adjoint)
      u_d_vh_x = math_ops.matmul(u, d_vh_x)
      return leading_term + u_d_vh_x
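
  # For reference, a NumPy sketch of the determinant lemma used in
  # `_determinant` below (illustrative only; assumes dense, unbatched
  # l, d, u, v, and is not part of this module's API):
  #
  #   c = np.linalg.inv(d) + v.conj().T @ np.linalg.solve(l, u)  # capacitance
  #   det_a = np.linalg.det(c) * np.linalg.det(d) * np.linalg.det(l)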

  def _determinant(self):
    if self.is_positive_definite:
      return math_ops.exp(self.log_abs_determinant())
    # The matrix determinant lemma gives
    # https://en.wikipedia.org/wiki/Matrix_determinant_lemma
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    # where C is sometimes known as the capacitance matrix,
    #   C := D^{-1} + V^H L^{-1} U
    u, v = self._get_uv_as_tensors()
    det_c = linalg_ops.matrix_determinant(self._make_capacitance(u=u, v=v))
    det_d = self.diag_operator.determinant()
    det_l = self.base_operator.determinant()
    return det_c * det_d * det_l

  def _diag_part(self):
    # [U D V^H]_{ii} = sum_{jk} U_{ij} D_{jk} conj(V_{ik})
    #                = sum_{j}  U_{ij} D_{jj} conj(V_{ij})
    u, v = self._get_uv_as_tensors()
    product = u * math_ops.conj(v)
    if self.diag_update is not None:
      product *= array_ops.expand_dims(self.diag_update, axis=-2)
    return (
        math_ops.reduce_sum(product, axis=-1) + self.base_operator.diag_part())

  def _log_abs_determinant(self):
    u, v = self._get_uv_as_tensors()
    # Recall
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    log_abs_det_d = self.diag_operator.log_abs_determinant()
    log_abs_det_l = self.base_operator.log_abs_determinant()

    if self._use_cholesky:
      chol_cap_diag = array_ops.matrix_diag_part(
          linalg_ops.cholesky(self._make_capacitance(u=u, v=v)))
      log_abs_det_c = 2 * math_ops.reduce_sum(
          math_ops.log(chol_cap_diag), axis=[-1])
    else:
      det_c = linalg_ops.matrix_determinant(self._make_capacitance(u=u, v=v))
      log_abs_det_c = math_ops.log(math_ops.abs(det_c))
      if self.dtype.is_complex:
        log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)

    return log_abs_det_c + log_abs_det_d + log_abs_det_l
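
  # For reference, a NumPy sketch of the Woodbury identity as applied in
  # `_solve` below (illustrative only; assumes dense, unbatched l, d, u, v,
  # and is not part of this module's API):
  #
  #   c = np.linalg.inv(d) + v.conj().T @ np.linalg.solve(l, u)  # capacitance
  #   linv_rhs = np.linalg.solve(l, rhs)
  #   a_inv_rhs = linv_rhs - np.linalg.solve(
  #       l, u @ np.linalg.solve(c, v.conj().T @ linv_rhs))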

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    if self.base_operator.is_non_singular is False:
      raise ValueError(
          "Solve not implemented unless this is a perturbation of a "
          "non-singular LinearOperator.")
    # The Woodbury formula gives:
    # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
    #   (L + UDV^H)^{-1}
    #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U.
    # Note also that, with ^{-H} being the inverse of the adjoint,
    #   (L + UDV^H)^{-H}
    #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
    l = self.base_operator
    if adjoint:
      # If adjoint, U and V have flipped roles in the operator.
      v, u = self._get_uv_as_tensors()
      # Capacitance should still be computed with u=self.u and v=self.v, which
      # after the "flip" on the line above means u=v, v=u. I.e. no need to
      # "flip" in the capacitance call, since the call to
      # matrix_solve_with_broadcast below is done with the `adjoint` argument,
      # and this takes care of things.
      capacitance = self._make_capacitance(u=v, v=u)
    else:
      u, v = self._get_uv_as_tensors()
      capacitance = self._make_capacitance(u=u, v=v)

    # L^{-1} rhs
    linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    # V^H L^{-1} rhs
    vh_linv_rhs = math_ops.matmul(v, linv_rhs, adjoint_a=True)
    # C^{-1} V^H L^{-1} rhs
    if self._use_cholesky:
      capinv_vh_linv_rhs = linalg_ops.cholesky_solve(
          linalg_ops.cholesky(capacitance), vh_linv_rhs)
    else:
      capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
          capacitance, vh_linv_rhs, adjoint=adjoint)
    # U C^{-1} V^H L^{-1} rhs
    u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs)
    # L^{-1} U C^{-1} V^H L^{-1} rhs
    linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs, adjoint=adjoint)

    # L^{-1} rhs - L^{-1} U C^{-1} V^H L^{-1} rhs
    return linv_rhs - linv_u_capinv_vh_linv_rhs

  def _make_capacitance(self, u, v):
    # C := D^{-1} + V^H L^{-1} U
    # which is sometimes known as the "capacitance" matrix.

    # L^{-1} U
    linv_u = self.base_operator.solve(u)
    # V^H L^{-1} U
    vh_linv_u = math_ops.matmul(v, linv_u, adjoint_a=True)

    # D^{-1} + V^H L^{-1} U
    capacitance = self._diag_operator.inverse().add_to_tensor(vh_linv_u)
    return capacitance

  @property
  def _composite_tensor_fields(self):
    return ("base_operator", "u", "diag_update", "v", "is_diag_update_positive")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {
        "base_operator": 0,
        "u": 2,
        "diag_update": 1,
        "v": 2
    }
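
# For a quick sanity check of this operator (a sketch; not part of the
# module), one can compare against a dense computation:
#
#   import tensorflow as tf
#   u = tf.random.normal([3, 2])
#   base = tf.linalg.LinearOperatorIdentity(num_rows=3)
#   operator = tf.linalg.LinearOperatorLowRankUpdate(base, u=u)
#   # With the defaults, D is the identity and v = u, so A = I + U U^H.
#   dense = tf.eye(3) + tf.matmul(u, u, adjoint_b=True)
#   tf.debugging.assert_near(operator.to_dense(), dense)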