• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a diagonal matrix."""
16
17from tensorflow.python.framework import ops
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import check_ops
20from tensorflow.python.ops import math_ops
21from tensorflow.python.ops.linalg import linalg_impl as linalg
22from tensorflow.python.ops.linalg import linear_operator
23from tensorflow.python.ops.linalg import linear_operator_util
24from tensorflow.python.util.tf_export import tf_export
25
# Names exported by `from <this module> import *`.
__all__ = [
    "LinearOperatorDiag",
]
27
28
@tf_export("linalg.LinearOperatorDiag")
@linear_operator.make_composite_tensor
class LinearOperatorDiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square diagonal matrix.

  This operator acts like a [batch] diagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorDiag` is initialized with a (batch) vector holding the
  diagonal of `A`.

  ```python
  # Create a 2 x 2 diagonal linear operator.
  diag = [1., -1.]
  operator = LinearOperatorDiag(diag)

  operator.to_dense()
  ==> [[1.,  0.]
       [0., -1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  diag = tf.random.normal(shape=[2, 3, 4])
  operator = LinearOperatorDiag(diag)

  # Create a shape [2, 1, 4, 2] vector.  Note that this shape is compatible
  # since the batch dimensions, [2, 1], are broadcast to
  # operator.batch_shape = [2, 3].
  y = tf.random.normal(shape=[2, 1, 4, 2])
  x = operator.solve(y)
  ==> operator.matmul(x) = y
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N * R` multiplications.
  * `operator.solve(x)` involves `N` divisions and `N * R` multiplications.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diag,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorDiag"):
    r"""Initialize a `LinearOperatorDiag`.

    Args:
      diag:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The diagonal of the operator.  Intended dtypes: `float16`, `float32`,
        `float64`, `complex64`, `complex128`.  Note that this class does not
        itself validate the dtype of `diag`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        A diagonal operator is always square, so this is auto-set to `True`.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `diag` has statically known rank 0, if `diag.dtype` is
        real and `is_self_adjoint` is `False`, or if `is_square` is `False`.
    """
    # Record the constructor arguments exactly as passed, before any of the
    # hints are auto-set below.
    parameters = dict(
        diag=diag,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diag]):
      self._diag = linear_operator_util.convert_nonref_to_tensor(
          diag, name="diag")
      self._check_diag(self._diag)

      # Check and auto-set hints.  A real diagonal matrix equals its own
      # (hermitian) transpose, so `is_self_adjoint=False` is contradictory.
      if not self._diag.dtype.is_complex:
        if is_self_adjoint is False:
          raise ValueError("A real diagonal operator is always self adjoint.")
        else:
          is_self_adjoint = True

      if is_square is False:
        raise ValueError("Only square diagonal operators currently supported.")
      is_square = True

      super(LinearOperatorDiag, self).__init__(
          dtype=self._diag.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_diag(self, diag):
    """Static check of diag: raise `ValueError` if its known rank is 0."""
    # Only checked when the rank is statically known; unknown rank passes.
    if diag.shape.ndims is not None and diag.shape.ndims < 1:
      raise ValueError("Argument diag must have at least 1 dimension.  "
                       "Found: %s" % diag)

  def _shape(self):
    """Static shape: the diag shape with its last dimension repeated."""
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._diag.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    """Dynamic-shape counterpart of `_shape`."""
    d_shape = array_ops.shape(self._diag)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  @property
  def diag(self):
    """The diagonal passed at construction, after conversion in `__init__`."""
    return self._diag

  def _assert_non_singular(self):
    # A diagonal matrix is singular iff some diagonal entry is zero.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _assert_positive_definite(self):
    # Positive definiteness here means Re(x^H A x) > 0 (see `__init__`); for
    # a diagonal `A` that reduces to every diagonal entry having positive
    # real part.
    if self.dtype.is_complex:
      message = (
          "Diagonal operator had diagonal entries with non-positive real part, "
          "thus was not positive definite.")
    else:
      message = (
          "Real diagonal operator had non-positive diagonal entries, "
          "thus was not positive definite.")

    return check_ops.assert_positive(
        math_ops.real(self._diag),
        message=message)

  def _assert_self_adjoint(self):
    # A diagonal matrix is self-adjoint iff its diagonal is real-valued.
    return linear_operator_util.assert_zero_imag_part(
        self._diag,
        message=(
            "This diagonal operator contained non-zero imaginary values.  "
            "Thus it was not self-adjoint."))

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Row i of the result is diag[i] * row i of x.  Broadcasting against
    # diag[..., :, None] avoids materializing the N x N matrix.
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    x = linalg.adjoint(x) if adjoint_arg else x
    diag_mat = array_ops.expand_dims(diag_term, -1)
    return diag_mat * x

  def _matvec(self, x, adjoint=False):
    # Matrix-vector product with a diagonal matrix is an elementwise product.
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    return diag_term * x

  def _determinant(self):
    # The determinant of a diagonal matrix is the product of its diagonal.
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    # log|det A| = sum_i log|diag[i]|.
    log_det = math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])
    # For complex operators, cast the result back to the operator's dtype.
    if self.dtype.is_complex:
      log_det = math_ops.cast(log_det, dtype=self.dtype)
    return log_det

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # Take one reciprocal per diagonal entry (N divisions), then scale each
    # row of rhs by it (N * R multiplications).
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
    return rhs * inv_diag_mat

  def _to_dense(self):
    return array_ops.matrix_diag(self._diag)

  def _diag_part(self):
    return self.diag

  def _add_to_tensor(self, x):
    # A + x differs from x only on the diagonal, so add this operator's
    # diagonal to x's diagonal and write it back into x.
    x_diag = array_ops.matrix_diag_part(x)
    new_diag = self._diag + x_diag
    return array_ops.matrix_set_diag(x, new_diag)

  def _eigvals(self):
    # The eigenvalues of a diagonal matrix are its diagonal entries.  Convert
    # explicitly in case `self.diag` is not already a plain `Tensor`
    # (presumably, e.g., a variable kept by `convert_nonref_to_tensor`).
    return ops.convert_to_tensor_v2_with_dispatch(self.diag)

  def _cond(self):
    # Condition number: largest over smallest diagonal modulus.  May be
    # inf/nan if the diagonal contains zeros.
    abs_diag = math_ops.abs(self.diag)
    return (math_ops.reduce_max(abs_diag, axis=-1) /
            math_ops.reduce_min(abs_diag, axis=-1))

  @property
  def _composite_tensor_fields(self):
    # Parameter names consumed by `make_composite_tensor` (presumably the
    # tensor-valued components of the composite).
    return ("diag",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # `diag` has shape [B1,...,Bb, N]: one non-batch dimension.
    return {"diag": 1}
271