# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Perturb a `LinearOperator` with a rank `K` update."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowRankUpdate",
]


@tf_export("linalg.LinearOperatorLowRankUpdate")
class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
40  """Perturb a `LinearOperator` with a rank `K` update.
41
42  This operator acts like a [batch] matrix `A` with shape
43  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
44  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
45  an `M x N` matrix.

  `LinearOperatorLowRankUpdate` represents `A = L + U D V^H`, where

  ```
  L is a LinearOperator representing [batch] M x N matrices.
  U is a [batch] M x K matrix.  Typically K << M.
  D is a [batch] K x K matrix.
  V is a [batch] N x K matrix.  Typically K << N.
  V^H is the Hermitian transpose (adjoint) of V.
  ```

  If `M = N`, determinants and solves are done using the matrix determinant
  lemma and Woodbury identities, and thus require L and D to be non-singular.

  Solves and determinants will be attempted unless the "is_non_singular"
  property of L and D is False.

  In the event that L and D are positive-definite, and U = V, solves and
  determinants can be done using a Cholesky factorization.

  ```python
  # Create a 3 x 3 diagonal linear operator.
  diag_operator = LinearOperatorDiag(
      diag=[1., 2., 3.], is_non_singular=True, is_self_adjoint=True,
      is_positive_definite=True)

  # Perturb with a rank 2 perturbation
  operator = LinearOperatorLowRankUpdate(
      base_operator=diag_operator,
      u=[[1., 2.], [-1., 3.], [0., 0.]],
      diag_update=[11., 12.],
      v=[[1., 2.], [-1., 3.], [10., 10.]])

  operator.shape
  ==> [3, 3]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  ### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```
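
  For example, a minimal sketch of the batch case (shapes and values are
  illustrative only):

  ```python
  # Batch of two 3 x 3 identity operators, each given a rank 1 update.
  base = LinearOperatorIdentity(num_rows=3, batch_shape=[2])
  operator = LinearOperatorLowRankUpdate(base, u=tf.ones([2, 3, 1]))

  operator.shape
  ==> [2, 3, 3]

  operator.matmul(tf.ones([2, 3, 4]))
  ==> Shape [2, 3, 4] Tensor
  ```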

  ### Performance

  Suppose `operator` is a `LinearOperatorLowRankUpdate` of shape `[M, N]`,
  made from a rank `K` update of `base_operator` which performs `.matmul(x)` on
  `x` having `x.shape = [N, R]` with `O(L_matmul*N*R)` complexity (and similarly
  for `solve`, `determinant`).  Then, if `x.shape = [N, R]`,

  * `operator.matmul(x)` is `O(L_matmul*N*R + K*N*R)`

  and if `M = N`,

  * `operator.solve(x)` is `O(L_matmul*N*R + N*K*R + K^2*R + K^3)`
  * `operator.determinant()` is `O(L_determinant + L_solve*N*K + K^2*N + K^3)`

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
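
  For example, if `base_operator` is diagonal, `.matmul(x)` costs `O(N*R)`
  (so `L_matmul = O(1)` in the notation above), and a rank `K` update gives
  `operator.matmul(x)` cost `O(K*N*R)`, versus `O(N^2*R)` for the same product
  with a dense `N x N` matrix.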

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular`, `self_adjoint`, `positive_definite`,
  `diag_update_positive` and `square`. These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               base_operator,
               u,
               diag_update=None,
               v=None,
               is_diag_update_positive=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowRankUpdate"):
143    """Initialize a `LinearOperatorLowRankUpdate`.
144
145    This creates a `LinearOperator` of the form `A = L + U D V^H`, with
146    `L` a `LinearOperator`, `U, V` both [batch] matrices, and `D` a [batch]
147    diagonal matrix.
148
149    If `L` is non-singular, solves and determinants are available.
150    Solves/determinants both involve a solve/determinant of a `K x K` system.
151    In the event that L and D are self-adjoint positive-definite, and U = V,
152    this can be done using a Cholesky factorization.  The user should set the
153    `is_X` matrix property hints, which will trigger the appropriate code path.
154
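    For example, a minimal sketch of hints that trigger the Cholesky path
    (values illustrative only):

    ```python
    base = LinearOperatorDiag(diag=[2., 3.], is_self_adjoint=True,
                              is_positive_definite=True)
    operator = LinearOperatorLowRankUpdate(
        base, u=[[1.], [0.]], diag_update=[4.],
        is_diag_update_positive=True)  # v is None, so V = U.
    ```
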
    Args:
      base_operator:  Shape `[B1,...,Bb, M, N]`.
      u:  Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
        This is `U` above.
      diag_update:  Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
        as `base_operator`.  This is the diagonal of `D` above.
        Defaults to `D` being the identity operator.
      v:  Optional `Tensor` of same `dtype` as `u`, and shape
        `[B1,...,Bb, N, K]`.  Defaults to `v = u`, in which case the
        perturbation is symmetric.  If `M != N`, then `v` must be set since
        the perturbation is not square.
      is_diag_update_positive:  Python `bool`.
        If `True`, expect `diag_update > 0`.
      is_non_singular:  Expect that this operator is non-singular.
        Default is `None`, unless `is_positive_definite` is auto-set to be
        `True` (see below).
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Default is `None`, unless `base_operator` is self-adjoint
        and `v = None` (meaning `u = v`), in which case this defaults to `True`.
      is_positive_definite:  Expect that this operator is positive definite.
        Default is `None`, unless `base_operator` is positive-definite,
        `v = None` (meaning `u = v`), and `is_diag_update_positive` is `True`,
        in which case this defaults to `True`.
        Note that we say an operator is positive definite when the quadratic
        form `x^H A x` has positive real part for all nonzero `x`.
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_X` flags are set in an inconsistent way.
    """
    dtype = base_operator.dtype

    if diag_update is not None:
      if is_diag_update_positive and dtype.is_complex:
        logging.warn("Note: setting is_diag_update_positive with a complex "
                     "dtype means that the diagonal is real and positive.")

    if diag_update is None:
      if is_diag_update_positive is False:
        raise ValueError(
            "Default diagonal is the identity, which is positive.  However, "
            "user set 'is_diag_update_positive' to False.")
      is_diag_update_positive = True

    # In this case, we can use a Cholesky decomposition to help us solve/det.
    self._use_cholesky = (
        base_operator.is_positive_definite and base_operator.is_self_adjoint
        and is_diag_update_positive
        and v is None)

    # Possibly auto-set some characteristic flags from None to True.
    # If the flags were set (by the user) incorrectly to False, then raise.
    if base_operator.is_self_adjoint and v is None and not dtype.is_complex:
      if is_self_adjoint is False:
        raise ValueError(
            "A = L + UDU^H, with L self-adjoint and D real diagonal.  Since"
            " UDU^H is self-adjoint, this must be a self-adjoint operator.")
      is_self_adjoint = True

    # The condition for using a Cholesky is sufficient for SPD, and
    # no weaker choice of these hints leads to SPD.  Therefore,
    # the following line reads "if hints indicate SPD..."
    if self._use_cholesky:
      if (
          is_positive_definite is False
          or is_self_adjoint is False
          or is_non_singular is False):
        raise ValueError(
            "Arguments imply this is self-adjoint positive-definite operator.")
      is_positive_definite = True
      is_self_adjoint = True

    values = base_operator.graph_parents + [u, diag_update, v]
    with ops.name_scope(name, values=values):

      # Create U and V.
      self._u = ops.convert_to_tensor(u, name="u")
      if v is None:
        self._v = self._u
      else:
        self._v = ops.convert_to_tensor(v, name="v")

      if diag_update is None:
        self._diag_update = None
      else:
        self._diag_update = ops.convert_to_tensor(
            diag_update, name="diag_update")

      # Create base_operator L.
      self._base_operator = base_operator
      graph_parents = base_operator.graph_parents + [
          self.u, self._diag_update, self.v]
      graph_parents = [p for p in graph_parents if p is not None]

      super(LinearOperatorLowRankUpdate, self).__init__(
          dtype=self._base_operator.dtype,
          graph_parents=graph_parents,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

      # Create the diagonal operator D.
      self._set_diag_operators(diag_update, is_diag_update_positive)
      self._is_diag_update_positive = is_diag_update_positive

      self._check_shapes()

      # Pre-compute the so-called "capacitance" matrix
      #   C := D^{-1} + V^H L^{-1} U
      self._capacitance = self._make_capacitance()
      if self._use_cholesky:
        self._chol_capacitance = linalg_ops.cholesky(self._capacitance)
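        # The solve and log-det code paths below reuse this Cholesky factor.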

  def _check_shapes(self):
    """Static check that shapes are compatible."""
    # Broadcast shape also checks that u and v are compatible.
    uv_shape = array_ops.broadcast_static_shape(
        self.u.get_shape(), self.v.get_shape())

    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    tensor_shape.Dimension(
        self.base_operator.domain_dimension).assert_is_compatible_with(
            uv_shape[-2])

    if self._diag_update is not None:
      tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
          self._diag_update.get_shape()[-1])
      array_ops.broadcast_static_shape(
          batch_shape, self._diag_update.get_shape()[:-1])

  def _set_diag_operators(self, diag_update, is_diag_update_positive):
    """Set attributes self._diag_update and self._diag_operator."""
    if diag_update is not None:
      self._diag_operator = linear_operator_diag.LinearOperatorDiag(
          self._diag_update, is_positive_definite=is_diag_update_positive)
      self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
          1. / self._diag_update, is_positive_definite=is_diag_update_positive)
    else:
      if tensor_shape.dimension_value(self.u.shape[-1]) is not None:
        r = tensor_shape.dimension_value(self.u.shape[-1])
      else:
        r = array_ops.shape(self.u)[-1]
      self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
          num_rows=r, dtype=self.dtype)
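      # D defaults to the K x K identity, which is its own inverse, so the
      # same operator serves as D^{-1}.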
      self._diag_inv_operator = self._diag_operator

  @property
  def u(self):
    """If this operator is `A = L + U D V^H`, this is the `U`."""
    return self._u

  @property
  def v(self):
    """If this operator is `A = L + U D V^H`, this is the `V`."""
    return self._v

  @property
  def is_diag_update_positive(self):
    """If this operator is `A = L + U D V^H`, this hints `D > 0` elementwise."""
    return self._is_diag_update_positive

  @property
  def diag_update(self):
    """If this operator is `A = L + U D V^H`, this is the diagonal of `D`."""
    return self._diag_update

  @property
  def diag_operator(self):
    """If this operator is `A = L + U D V^H`, this is `D`."""
    return self._diag_operator

  @property
  def base_operator(self):
    """If this operator is `A = L + U D V^H`, this is the `L`."""
    return self._base_operator

  def _shape(self):
    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape,
        self.u.get_shape()[:-2])
    return batch_shape.concatenate(self.base_operator.shape[-2:])

  def _shape_tensor(self):
    batch_shape = array_ops.broadcast_dynamic_shape(
        self.base_operator.batch_shape_tensor(),
        array_ops.shape(self.u)[:-2])
    return array_ops.concat(
        [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    u = self.u
    v = self.v
    l = self.base_operator
    d = self.diag_operator

    leading_term = l.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

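    # Apply the update term factor by factor, right to left, so the full
    # M x N matrix U D V^H (or its adjoint V D^H U^H) is never formed.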
    if adjoint:
      uh_x = linear_operator_util.matmul_with_broadcast(
          u, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_uh_x = d.matmul(uh_x, adjoint=adjoint)
      v_d_uh_x = linear_operator_util.matmul_with_broadcast(
          v, d_uh_x)
      return leading_term + v_d_uh_x
    else:
      vh_x = linear_operator_util.matmul_with_broadcast(
          v, x, adjoint_a=True, adjoint_b=adjoint_arg)
      d_vh_x = d.matmul(vh_x, adjoint=adjoint)
      u_d_vh_x = linear_operator_util.matmul_with_broadcast(u, d_vh_x)
      return leading_term + u_d_vh_x

  def _determinant(self):
    if self.is_positive_definite:
      return math_ops.exp(self.log_abs_determinant())
    # The matrix determinant lemma gives
    # https://en.wikipedia.org/wiki/Matrix_determinant_lemma
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    # where C is sometimes known as the capacitance matrix,
    #   C := D^{-1} + V^H L^{-1} U
    det_c = linalg_ops.matrix_determinant(self._capacitance)
    det_d = self.diag_operator.determinant()
    det_l = self.base_operator.determinant()
    return det_c * det_d * det_l

  def _log_abs_determinant(self):
    # Recall
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    #                  = det(C) det(D) det(L)
    log_abs_det_d = self.diag_operator.log_abs_determinant()
    log_abs_det_l = self.base_operator.log_abs_determinant()

    if self._use_cholesky:
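      # C is SPD here; with Cholesky factor R (C = R R^H),
      # det(C) = prod(diag(R))^2, so log|det(C)| = 2 * sum(log(diag(R))).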
      chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance)
      log_abs_det_c = 2 * math_ops.reduce_sum(
          math_ops.log(chol_cap_diag), axis=[-1])
    else:
      det_c = linalg_ops.matrix_determinant(self._capacitance)
      log_abs_det_c = math_ops.log(math_ops.abs(det_c))
      if self.dtype.is_complex:
        log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)

    return log_abs_det_c + log_abs_det_d + log_abs_det_l

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    if self.base_operator.is_non_singular is False:
      raise ValueError(
          "Solve not implemented unless this is a perturbation of a "
          "non-singular LinearOperator.")
    # The Woodbury formula gives:
    # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
    #   (L + UDV^H)^{-1}
    #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U
    # Note also that, with ^{-H} being the inverse of the adjoint,
    #   (L + UDV^H)^{-H}
    #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
    l = self.base_operator
    if adjoint:
      v = self.u
      u = self.v
    else:
      v = self.v
      u = self.u

    # L^{-1} rhs
    linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    # V^H L^{-1} rhs
    vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
        v, linv_rhs, adjoint_a=True)
    # C^{-1} V^H L^{-1} rhs
    if self._use_cholesky:
      capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast(
          self._chol_capacitance, vh_linv_rhs)
    else:
      capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
          self._capacitance, vh_linv_rhs, adjoint=adjoint)
    # U C^{-1} V^H L^{-1} rhs
    u_capinv_vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
        u, capinv_vh_linv_rhs)
    # L^{-1} U C^{-1} V^H L^{-1} rhs
    linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs, adjoint=adjoint)

    # L^{-1} rhs - L^{-1} U C^{-1} V^H L^{-1} rhs
    return linv_rhs - linv_u_capinv_vh_linv_rhs

  def _make_capacitance(self):
    # C := D^{-1} + V^H L^{-1} U
    # which is sometimes known as the "capacitance" matrix.

    # L^{-1} U
    linv_u = self.base_operator.solve(self.u)
    # V^H L^{-1} U
    vh_linv_u = linear_operator_util.matmul_with_broadcast(
        self.v, linv_u, adjoint_a=True)

    # C = D^{-1} + V^H L^{-1} U
    capacitance = self._diag_inv_operator.add_to_tensor(vh_linv_u)
    return capacitance