# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perturb a `LinearOperator` with a rank `K` update."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowRankUpdate",
]


@tf_export("linalg.LinearOperatorLowRankUpdate")
@linear_operator.make_composite_tensor
class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
  """Perturb a `LinearOperator` with a rank `K` update.

  This operator acts like a [batch] matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  `LinearOperatorLowRankUpdate` represents `A = L + U D V^H`, where

  ```
  L, is a LinearOperator representing [batch] M x N matrices
  U, is a [batch] M x K matrix.  Typically K << M.
  D, is a [batch] K x K matrix.
  V, is a [batch] N x K matrix.  Typically K << N.
  V^H is the Hermitian transpose (adjoint) of V.
  ```

  If `M = N`, determinants and solves are done using the matrix determinant
  lemma and Woodbury identities, and thus require L and D to be non-singular.

  Solves and determinants will be attempted unless the "is_non_singular"
  property of L and D is False.

  In the event that L and D are positive-definite, and U = V, solves and
  determinants can be done using a Cholesky factorization.

  ```python
  # Create a 3 x 3 diagonal linear operator.
  diag_operator = LinearOperatorDiag(
      diag=[1., 2., 3.], is_non_singular=True, is_self_adjoint=True,
      is_positive_definite=True)

  # Perturb with a rank 2 perturbation
  operator = LinearOperatorLowRankUpdate(
      base_operator=diag_operator,
      u=[[1., 2.], [-1., 3.], [0., 0.]],
      diag_update=[11., 12.],
      v=[[1., 2.], [-1., 3.], [10., 10.]])

  operator.shape
  ==> [3, 3]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  ### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```
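
  For example, a minimal sketch of the batch case (shapes are illustrative;
  `LinearOperatorIdentity` is just a stand-in for any base operator):

  ```python
  base = LinearOperatorIdentity(num_rows=3, batch_shape=[2])
  u = ... Shape [2, 3, 1] Tensor, a batch of rank 1 updates
  operator = LinearOperatorLowRankUpdate(base_operator=base, u=u)

  x = ... Shape [2, 3, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 3, 4] Tensor
  ```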

  ### Performance

  Suppose `operator` is a `LinearOperatorLowRankUpdate` of shape `[M, N]`,
  made from a rank `K` update of `base_operator` which performs `.matmul(x)` on
  `x` having `x.shape = [N, R]` with `O(L_matmul*N*R)` complexity (and similarly
  for `solve` and `determinant`).  Then, if `x.shape = [N, R]`,

  * `operator.matmul(x)` is `O(L_matmul*N*R + K*N*R)`

  and if `M = N`,

  * `operator.solve(x)` is `O(L_matmul*N*R + N*K*R + K^2*R + K^3)`
  * `operator.determinant()` is `O(L_determinant + L_solve*N*K + K^2*N + K^3)`

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, the complexity of every operation increases by a factor
  of `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular`, `self_adjoint`, `positive_definite`,
  `diag_update_positive` and `square`. These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               base_operator,
               u,
               diag_update=None,
               v=None,
               is_diag_update_positive=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowRankUpdate"):
    """Initialize a `LinearOperatorLowRankUpdate`.

    This creates a `LinearOperator` of the form `A = L + U D V^H`, with
    `L` a `LinearOperator`, `U, V` both [batch] matrices, and `D` a [batch]
    diagonal matrix.

    If `L` is non-singular, solves and determinants are available.
    Solves/determinants both involve a solve/determinant of a `K x K` system.
    In the event that L and D are self-adjoint positive-definite, and U = V,
    this can be done using a Cholesky factorization.  The user should set the
    `is_X` matrix property hints, which will trigger the appropriate code path.
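
    For example, a minimal sketch of hints that select the Cholesky path
    (values are illustrative):

    ```python
    base = LinearOperatorDiag(
        diag=[2., 3.], is_self_adjoint=True, is_positive_definite=True)
    operator = LinearOperatorLowRankUpdate(
        base_operator=base,
        u=[[1.], [0.]],
        diag_update=[4.],
        is_diag_update_positive=True)
    ```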

    Args:
      base_operator:  Shape `[B1,...,Bb, M, N]`.
      u:  Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
        This is `U` above.
      diag_update:  Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
        as `base_operator`.  This is the diagonal of `D` above.
        Defaults to `D` being the identity operator.
      v:  Optional `Tensor` of same `dtype` as `u` and shape `[B1,...,Bb, N, K]`.
        Defaults to `v = u`, in which case the perturbation is symmetric.
        If `M != N`, then `v` must be set since the perturbation is not square.
      is_diag_update_positive:  Python `bool`.
        If `True`, expect `diag_update > 0`.
      is_non_singular:  Expect that this operator is non-singular.
        Default is `None`, unless `is_positive_definite` is auto-set to be
        `True` (see below).
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Default is `None`, unless `base_operator` is self-adjoint
        and `v = None` (meaning `u=v`), in which case this defaults to `True`.
      is_positive_definite:  Expect that this operator is positive definite.
        Default is `None`, unless `base_operator` is positive-definite,
        `v = None` (meaning `u=v`), and `is_diag_update_positive` is `True`,
        in which case this defaults to `True`.
        Note that we say an operator is positive definite when the quadratic
        form `x^H A x` has positive real part for all nonzero `x`.
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_X` flags are set in an inconsistent way.
    """
    parameters = dict(
        base_operator=base_operator,
        u=u,
        diag_update=diag_update,
        v=v,
        is_diag_update_positive=is_diag_update_positive,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )
    dtype = base_operator.dtype

    if diag_update is not None:
      if is_diag_update_positive and dtype.is_complex:
        logging.warn("Note: setting is_diag_update_positive with a complex "
                     "dtype means that diagonal is real and positive.")

    if diag_update is None:
      if is_diag_update_positive is False:
        raise ValueError(
            "Default diagonal is the identity, which is positive.  However, "
            "user set 'is_diag_update_positive' to False.")
      is_diag_update_positive = True

    # In this case, we can use a Cholesky decomposition to help us solve/det.
    self._use_cholesky = (
        base_operator.is_positive_definite and base_operator.is_self_adjoint
        and is_diag_update_positive
        and v is None)

    # Possibly auto-set some characteristic flags from None to True.
    # If the flags were set (by the user) incorrectly to False, then raise.
    if base_operator.is_self_adjoint and v is None and not dtype.is_complex:
      if is_self_adjoint is False:
        raise ValueError(
            "A = L + UDU^H, with L self-adjoint and D real diagonal.  Since"
            " UDU^H is self-adjoint, this must be a self-adjoint operator.")
      is_self_adjoint = True

    # The condition for using a Cholesky is sufficient for SPD, and
    # no weaker choice of these hints leads to SPD.  Therefore,
    # the following line reads "if hints indicate SPD..."
    if self._use_cholesky:
      if (
          is_positive_definite is False
          or is_self_adjoint is False
          or is_non_singular is False):
        raise ValueError(
            "Arguments imply this is a self-adjoint positive-definite "
            "operator.")
      is_positive_definite = True
      is_self_adjoint = True

    with ops.name_scope(name):

      # Create U and V.
      self._u = linear_operator_util.convert_nonref_to_tensor(u, name="u")
      if v is None:
        self._v = self._u
      else:
        self._v = linear_operator_util.convert_nonref_to_tensor(v, name="v")

      if diag_update is None:
        self._diag_update = None
      else:
        self._diag_update = linear_operator_util.convert_nonref_to_tensor(
            diag_update, name="diag_update")

      # Create base_operator L.
      self._base_operator = base_operator

      super(LinearOperatorLowRankUpdate, self).__init__(
          dtype=self._base_operator.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      # Create the diagonal operator D.
      self._set_diag_operators(diag_update, is_diag_update_positive)
      self._is_diag_update_positive = is_diag_update_positive

      self._check_shapes()

  def _check_shapes(self):
    """Static check that shapes are compatible."""
    # Broadcast shape also checks that u and v are compatible.
    uv_shape = array_ops.broadcast_static_shape(
        self.u.shape, self.v.shape)

    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    tensor_shape.Dimension(
        self.base_operator.domain_dimension).assert_is_compatible_with(
            uv_shape[-2])

    if self._diag_update is not None:
      tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
          self._diag_update.shape[-1])
      array_ops.broadcast_static_shape(
          batch_shape, self._diag_update.shape[:-1])

  def _set_diag_operators(self, diag_update, is_diag_update_positive):
    """Set attributes self._diag_update and self._diag_operator."""
    if diag_update is not None:
      self._diag_operator = linear_operator_diag.LinearOperatorDiag(
          self._diag_update, is_positive_definite=is_diag_update_positive)
    else:
      if tensor_shape.dimension_value(self.u.shape[-1]) is not None:
        r = tensor_shape.dimension_value(self.u.shape[-1])
      else:
        r = array_ops.shape(self.u)[-1]
      self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
          num_rows=r, dtype=self.dtype)

  @property
  def u(self):
    """If this operator is `A = L + U D V^H`, this is the `U`."""
    return self._u

  @property
  def v(self):
    """If this operator is `A = L + U D V^H`, this is the `V`."""
    return self._v

  @property
  def is_diag_update_positive(self):
    """If this operator is `A = L + U D V^H`, this hints `D > 0` elementwise."""
    return self._is_diag_update_positive

  @property
  def diag_update(self):
    """If this operator is `A = L + U D V^H`, this is the diagonal of `D`."""
    return self._diag_update

  @property
  def diag_operator(self):
    """If this operator is `A = L + U D V^H`, this is `D`."""
    return self._diag_operator

  @property
  def base_operator(self):
    """If this operator is `A = L + U D V^H`, this is the `L`."""
    return self._base_operator

  def _assert_self_adjoint(self):
    # Recall this operator is:
    #   A = L + UDV^H.
    # So in one case, self-adjointness depends only on L.
    if self.u is self.v and self.diag_update is None:
      return self.base_operator.assert_self_adjoint()
    # In all other cases, sufficient conditions for self-adjoint can be found
    # efficiently. However, those conditions are not necessary conditions.
    return super(LinearOperatorLowRankUpdate, self).assert_self_adjoint()

  def _shape(self):
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape,
        self.diag_operator.batch_shape)
    batch_shape = array_ops.broadcast_static_shape(
        batch_shape,
        self.u.shape[:-2])
    batch_shape = array_ops.broadcast_static_shape(
        batch_shape,
        self.v.shape[:-2])
    return batch_shape.concatenate(self.base_operator.shape[-2:])

  def _shape_tensor(self):
    batch_shape = array_ops.broadcast_dynamic_shape(
        self.base_operator.batch_shape_tensor(),
        self.diag_operator.batch_shape_tensor())
    batch_shape = array_ops.broadcast_dynamic_shape(
        batch_shape,
        array_ops.shape(self.u)[:-2])
    batch_shape = array_ops.broadcast_dynamic_shape(
        batch_shape,
        array_ops.shape(self.v)[:-2])
    return array_ops.concat(
        [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)

  def _get_uv_as_tensors(self):
    """Get (self.u, self.v) as tensors (in case they were refs)."""
    u = ops.convert_to_tensor_v2_with_dispatch(self.u)
    if self.v is self.u:
      v = u
    else:
      v = ops.convert_to_tensor_v2_with_dispatch(self.v)
    return u, v

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    u, v = self._get_uv_as_tensors()
    l = self.base_operator
    d = self.diag_operator

    leading_term = l.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

    if adjoint:
      uh_x = math_ops.matmul(u, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_uh_x = d.matmul(uh_x, adjoint=adjoint)
      v_d_uh_x = math_ops.matmul(v, d_uh_x)
      return leading_term + v_d_uh_x
    else:
      vh_x = math_ops.matmul(v, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_vh_x = d.matmul(vh_x, adjoint=adjoint)
      u_d_vh_x = math_ops.matmul(u, d_vh_x)
      return leading_term + u_d_vh_x

  def _determinant(self):
    if self.is_positive_definite:
      return math_ops.exp(self.log_abs_determinant())
    # The matrix determinant lemma gives
    # https://en.wikipedia.org/wiki/Matrix_determinant_lemma
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    # where C is sometimes known as the capacitance matrix,
    #   C := D^{-1} + V^H L^{-1} U
    u, v = self._get_uv_as_tensors()
    det_c = linalg_ops.matrix_determinant(self._make_capacitance(u=u, v=v))
    det_d = self.diag_operator.determinant()
    det_l = self.base_operator.determinant()
    return det_c * det_d * det_l

  def _diag_part(self):
    # [U D V^H]_{ii} = sum_{jk} U_{ij} D_{jk} conj(V_{ik})
    #                = sum_{j}  U_{ij} D_{jj} conj(V_{ij})
    u, v = self._get_uv_as_tensors()
    product = u * math_ops.conj(v)
    if self.diag_update is not None:
      product *= array_ops.expand_dims(self.diag_update, axis=-2)
    return (
        math_ops.reduce_sum(product, axis=-1) + self.base_operator.diag_part())

  def _log_abs_determinant(self):
    u, v = self._get_uv_as_tensors()
    # Recall
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    log_abs_det_d = self.diag_operator.log_abs_determinant()
    log_abs_det_l = self.base_operator.log_abs_determinant()

    if self._use_cholesky:
      chol_cap_diag = array_ops.matrix_diag_part(
          linalg_ops.cholesky(self._make_capacitance(u=u, v=v)))
      log_abs_det_c = 2 * math_ops.reduce_sum(
          math_ops.log(chol_cap_diag), axis=[-1])
    else:
      det_c = linalg_ops.matrix_determinant(self._make_capacitance(u=u, v=v))
      log_abs_det_c = math_ops.log(math_ops.abs(det_c))
      if self.dtype.is_complex:
        log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)

    return log_abs_det_c + log_abs_det_d + log_abs_det_l

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    if self.base_operator.is_non_singular is False:
      raise ValueError(
          "Solve not implemented unless this is a perturbation of a "
          "non-singular LinearOperator.")
    # The Woodbury formula gives:
    # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
    #   (L + UDV^H)^{-1}
    #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U
    # Note also that, with ^{-H} being the inverse of the adjoint,
    #   (L + UDV^H)^{-H}
    #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
    l = self.base_operator
    if adjoint:
      # If adjoint, U and V have flipped roles in the operator.
      v, u = self._get_uv_as_tensors()
      # Capacitance should still be computed with u=self.u and v=self.v, which
      # after the "flip" on the line above means u=v, v=u. I.e. no need to
      # "flip" in the capacitance call, since the call to
      # matrix_solve_with_broadcast below is done with the `adjoint` argument,
      # and this takes care of things.
      capacitance = self._make_capacitance(u=v, v=u)
    else:
      u, v = self._get_uv_as_tensors()
      capacitance = self._make_capacitance(u=u, v=v)

    # L^{-1} rhs
    linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    # V^H L^{-1} rhs
    vh_linv_rhs = math_ops.matmul(v, linv_rhs, adjoint_a=True)
    # C^{-1} V^H L^{-1} rhs
    if self._use_cholesky:
      capinv_vh_linv_rhs = linalg_ops.cholesky_solve(
          linalg_ops.cholesky(capacitance), vh_linv_rhs)
    else:
      capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
          capacitance, vh_linv_rhs, adjoint=adjoint)
    # U C^{-1} V^H L^{-1} rhs
    u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs)
    # L^{-1} U C^{-1} V^H L^{-1} rhs
    linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs, adjoint=adjoint)

    # L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    return linv_rhs - linv_u_capinv_vh_linv_rhs

  def _make_capacitance(self, u, v):
    # C := D^{-1} + V^H L^{-1} U
    # which is sometimes known as the "capacitance" matrix.

    # L^{-1} U
    linv_u = self.base_operator.solve(u)
    # V^H L^{-1} U
    vh_linv_u = math_ops.matmul(v, linv_u, adjoint_a=True)

    # D^{-1} + V^H L^{-1} U
    capacitance = self._diag_operator.inverse().add_to_tensor(vh_linv_u)
    return capacitance

  @property
  def _composite_tensor_fields(self):
    return ("base_operator", "u", "diag_update", "v", "is_diag_update_positive")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {
        "base_operator": 0,
        "u": 2,
        "diag_update": 1,
        "v": 2
    }